diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py index b04a201c08d06..3e274368fb5bf 100644 --- a/airflow/callbacks/callback_requests.py +++ b/airflow/callbacks/callback_requests.py @@ -28,10 +28,17 @@ class CallbackRequest: :param full_filepath: File Path to use to run the callback :param msg: Additional Message that can be used for logging + :param processor_subdir: Directory used by Dag Processor when parsed the dag. """ - def __init__(self, full_filepath: str, msg: Optional[str] = None): + def __init__( + self, + full_filepath: str, + processor_subdir: Optional[str] = None, + msg: Optional[str] = None, + ): self.full_filepath = full_filepath + self.processor_subdir = processor_subdir self.msg = msg def __eq__(self, other): @@ -60,6 +67,7 @@ class TaskCallbackRequest(CallbackRequest): :param simple_task_instance: Simplified Task Instance representation :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback :param msg: Additional Message that can be used for logging to determine failure/zombie + :param processor_subdir: Directory used by Dag Processor when parsed the dag. """ def __init__( @@ -67,9 +75,10 @@ def __init__( full_filepath: str, simple_task_instance: "SimpleTaskInstance", is_failure_callback: Optional[bool] = True, + processor_subdir: Optional[str] = None, msg: Optional[str] = None, ): - super().__init__(full_filepath=full_filepath, msg=msg) + super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg) self.simple_task_instance = simple_task_instance self.is_failure_callback = is_failure_callback @@ -94,6 +103,7 @@ class DagCallbackRequest(CallbackRequest): :param full_filepath: File Path to use to run the callback :param dag_id: DAG ID :param run_id: Run ID for the DagRun + :param processor_subdir: Directory used by Dag Processor when parsed the dag. :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback :param msg: Additional Message that can be used for logging """ @@ -103,10 +113,11 @@ def __init__( full_filepath: str, dag_id: str, run_id: str, + processor_subdir: Optional[str], is_failure_callback: Optional[bool] = True, msg: Optional[str] = None, ): - super().__init__(full_filepath=full_filepath, msg=msg) + super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg) self.dag_id = dag_id self.run_id = run_id self.is_failure_callback = is_failure_callback @@ -118,8 +129,15 @@ class SlaCallbackRequest(CallbackRequest): :param full_filepath: File Path to use to run the callback :param dag_id: DAG ID + :param processor_subdir: Directory used by Dag Processor when parsed the dag. """ - def __init__(self, full_filepath: str, dag_id: str, msg: Optional[str] = None): - super().__init__(full_filepath, msg) + def __init__( + self, + full_filepath: str, + dag_id: str, + processor_subdir: Optional[str], + msg: Optional[str] = None, + ): + super().__init__(full_filepath, processor_subdir=processor_subdir, msg=msg) self.dag_id = dag_id diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 13b299d1fe674..4537b62d6b4ca 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -2121,6 +2121,14 @@ type: integer example: ~ default: "20" + - name: dag_stale_not_seen_duration + description: | + Only applicable if `[scheduler]standalone_dag_processor` is true. 
+ Time in seconds after which dags, which were not updated by Dag Processor are deactivated. + version_added: 2.4.0 + type: integer + example: ~ + default: "600" - name: use_job_schedule description: | Turn off scheduler use of cron intervals by setting this to False. diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 119873d9f20a2..7cd116369ecc2 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -1077,6 +1077,10 @@ standalone_dag_processor = False # in database. Contains maximum number of callbacks that are fetched during a single loop. max_callbacks_per_loop = 20 +# Only applicable if `[scheduler]standalone_dag_processor` is true. +# Time in seconds after which dags, which were not updated by Dag Processor are deactivated. +dag_stale_not_seen_duration = 600 + # Turn off scheduler use of cron intervals by setting this to False. # DAGs submitted manually in the web UI or with trigger_dag will still run. use_job_schedule = True diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index ecd1f24434d23..fd1098899e7a1 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -31,7 +31,8 @@ from datetime import datetime, timedelta from importlib import import_module from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast +from pathlib import Path +from typing import Any, Dict, List, NamedTuple, Optional, Union, cast from setproctitle import setproctitle from sqlalchemy.orm import Session @@ -57,9 +58,6 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks -if TYPE_CHECKING: - import pathlib - class DagParsingStat(NamedTuple): """Information on processing progress""" @@ -107,7 +105,7 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): def __init__( self, - dag_directory: str, + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, dag_ids: Optional[List[str]], @@ -116,7 +114,7 @@ def __init__( ): super().__init__() self._file_path_queue: List[str] = [] - self._dag_directory: str = dag_directory + self._dag_directory: os.PathLike = dag_directory self._max_runs = max_runs self._processor_timeout = processor_timeout self._dag_ids = dag_ids @@ -205,7 +203,7 @@ def wait_until_finished(self) -> None: @staticmethod def _run_processor_manager( - dag_directory: str, + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, signal_conn: MultiprocessingConnection, @@ -368,7 +366,7 @@ class DagFileProcessorManager(LoggingMixin): def __init__( self, - dag_directory: Union[str, "pathlib.Path"], + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, dag_ids: Optional[List[str]], @@ -379,7 +377,6 @@ def __init__( super().__init__() self._file_paths: List[str] = [] self._file_path_queue: List[str] = [] - self._dag_directory = dag_directory self._max_runs = max_runs # signal_conn is None for dag_processor_standalone mode. 
self._direct_scheduler_conn = signal_conn @@ -387,6 +384,7 @@ def __init__( self._dag_ids = dag_ids self._async_mode = async_mode self._parsing_start_time: Optional[int] = None + self._dag_directory = dag_directory # Set the signal conn in to non-blocking mode, so that attempting to # send when the buffer is full errors, rather than hangs for-ever @@ -397,6 +395,7 @@ def __init__( if self._async_mode and self._direct_scheduler_conn is not None: os.set_blocking(self._direct_scheduler_conn.fileno(), False) + self.standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") self._parallelism = conf.getint('scheduler', 'parsing_processes') if ( conf.get_mandatory_value('database', 'sql_alchemy_conn').startswith('sqlite') @@ -498,11 +497,13 @@ def _deactivate_stale_dags(self, session=None): fp: self.get_last_finish_time(fp) for fp in self.file_paths if self.get_last_finish_time(fp) } to_deactivate = set() - dags_parsed = ( - session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time) - .filter(DagModel.is_active) - .all() + query = session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).filter( + DagModel.is_active ) + if self.standalone_dag_processor: + query = query.filter(DagModel.processor_subdir == self.get_dag_directory()) + dags_parsed = query.all() + for dag in dags_parsed: # The largest valid difference between a DagFileStat's last_finished_time and a DAG's # last_parsed_time is _processor_timeout. Longer than that indicates that the DAG is @@ -540,7 +541,7 @@ def _run_parsing_loop(self): self._refresh_dag_dir() self.prepare_file_path_queue() max_callbacks_per_loop = conf.getint("scheduler", "max_callbacks_per_loop") - standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") + if self._async_mode: # If we're in async mode, we can start up straight away. 
If we're # in sync mode we need to be told to start a "loop" @@ -591,7 +592,7 @@ def _run_parsing_loop(self): self.waitables.pop(sentinel) self._processors.pop(processor.file_path) - if standalone_dag_processor: + if self.standalone_dag_processor: self._fetch_callbacks(max_callbacks_per_loop) self._deactivate_stale_dags() DagWarning.purge_inactive_dag_warnings() @@ -661,11 +662,12 @@ def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION): """Fetches callbacks from database and add them to the internal queue for execution.""" self.log.debug("Fetching callbacks from the database.") with prohibit_commit(session) as guard: - query = ( - session.query(DbCallbackRequest) - .order_by(DbCallbackRequest.priority_weight.asc()) - .limit(max_callbacks) - ) + query = session.query(DbCallbackRequest) + if self.standalone_dag_processor: + query = query.filter( + DbCallbackRequest.processor_subdir == self.get_dag_directory(), + ) + query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks) callbacks = with_row_locks( query, of=DbCallbackRequest, session=session, **skip_locked(session=session) ).all() @@ -743,7 +745,10 @@ def _refresh_dag_dir(self): else: dag_filelocs.append(fileloc) - SerializedDagModel.remove_deleted_dags(dag_filelocs) + SerializedDagModel.remove_deleted_dags( + alive_dag_filelocs=dag_filelocs, + processor_subdir=self.get_dag_directory(), + ) DagModel.deactivate_deleted_dags(self._file_paths) from airflow.models.dagcode import DagCode @@ -913,6 +918,16 @@ def get_run_count(self, file_path): stat = self._file_stats.get(file_path) return stat.run_count if stat else 0 + def get_dag_directory(self) -> str: + """ + Returns the dag_director as a string. + :rtype: str + """ + if isinstance(self._dag_directory, Path): + return str(self._dag_directory.resolve()) + else: + return str(self._dag_directory) + def set_file_paths(self, new_file_paths): """ Update this with a new set of paths to DAG definition files. 
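The `get_dag_directory()` helper added above is what lets `_deactivate_stale_dags` and `_fetch_callbacks` compare a stable string against the new `processor_subdir` columns, regardless of whether the manager was constructed with a plain string or a `pathlib.Path`. A minimal sketch of that normalization; the standalone function name is illustrative and not part of the patch:

import os
from pathlib import Path
from typing import Union


def normalize_dag_directory(dag_directory: Union[str, os.PathLike]) -> str:
    # Path objects are resolved to an absolute path; plain strings are returned
    # unchanged, so the stored processor_subdir stays stable across restarts.
    if isinstance(dag_directory, Path):
        return str(dag_directory.resolve())
    return str(dag_directory)


print(normalize_dag_directory(Path("/opt/airflow/dags")))  # '/opt/airflow/dags'
print(normalize_dag_directory("/opt/airflow/dags"))        # '/opt/airflow/dags'
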
@@ -986,10 +1001,14 @@ def collect_results(self) -> None: self.log.debug("%s file paths queued for processing", len(self._file_path_queue)) @staticmethod - def _create_process(file_path, pickle_dags, dag_ids, callback_requests): + def _create_process(file_path, pickle_dags, dag_ids, dag_directory, callback_requests): """Creates DagFileProcessorProcess instance.""" return DagFileProcessorProcess( - file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, callback_requests=callback_requests + file_path=file_path, + pickle_dags=pickle_dags, + dag_ids=dag_ids, + dag_directory=dag_directory, + callback_requests=callback_requests, ) def start_new_processes(self): @@ -1002,7 +1021,11 @@ def start_new_processes(self): callback_to_execute_for_file = self._callback_to_execute[file_path] processor = self._create_process( - file_path, self._pickle_dags, self._dag_ids, callback_to_execute_for_file + file_path, + self._pickle_dags, + self._dag_ids, + self.get_dag_directory(), + callback_to_execute_for_file, ) del self._callback_to_execute[file_path] diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index b7dac828e2894..fa1eb46c299f8 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -76,12 +76,14 @@ def __init__( file_path: str, pickle_dags: bool, dag_ids: Optional[List[str]], + dag_directory: str, callback_requests: List[CallbackRequest], ): super().__init__() self._file_path = file_path self._pickle_dags = pickle_dags self._dag_ids = dag_ids + self._dag_directory = dag_directory self._callback_requests = callback_requests # The process that was launched to process the given . @@ -111,6 +113,7 @@ def _run_file_processor( pickle_dags: bool, dag_ids: Optional[List[str]], thread_name: str, + dag_directory: str, callback_requests: List[CallbackRequest], ) -> None: """ @@ -154,7 +157,11 @@ def _run_file_processor( threading.current_thread().name = thread_name log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) - dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) + dag_file_processor = DagFileProcessor( + dag_ids=dag_ids, + dag_directory=dag_directory, + log=log, + ) result: Tuple[int, int] = dag_file_processor.process_file( file_path=file_path, pickle_dags=pickle_dags, @@ -188,6 +195,7 @@ def start(self) -> None: self._pickle_dags, self._dag_ids, f"DagFileProcessor{self._instance_id}", + self._dag_directory, self._callback_requests, ), name=f"DagFileProcessor{self._instance_id}-Process", @@ -356,10 +364,11 @@ class DagFileProcessor(LoggingMixin): UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE') - def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger): + def __init__(self, dag_ids: Optional[List[str]], dag_directory: str, log: logging.Logger): super().__init__() self.dag_ids = dag_ids self._log = log + self._dag_directory = dag_directory self.dag_warnings: Set[Tuple[str, str]] = set() @provide_session @@ -766,7 +775,7 @@ def process_file( session.commit() # Save individual DAGs in the ORM - dagbag.sync_to_db(session) + dagbag.sync_to_db(processor_subdir=self._dag_directory, session=session) session.commit() if pickle_dags: diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index d8b8073674305..50dd18d2c2fe8 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -26,6 +26,7 @@ import warnings from collections import defaultdict from datetime import datetime, timedelta +from pathlib import Path from typing 
import TYPE_CHECKING, Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple from sqlalchemy import func, not_, or_, text @@ -141,6 +142,7 @@ def __init__( # How many seconds do we wait for tasks to heartbeat before mark them as zombies. self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold') self._standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") + self._dag_stale_not_seen_duration = conf.getint("scheduler", "dag_stale_not_seen_duration") self.do_pickle = do_pickle super().__init__(*args, **kwargs) @@ -685,6 +687,7 @@ def _process_executor_events(self, session: Session) -> int: full_filepath=ti.dag_model.fileloc, simple_task_instance=SimpleTaskInstance.from_ti(ti), msg=msg % (ti, state, ti.state, info), + processor_subdir=ti.dag_model.processor_subdir, ) self.executor.send_callback(request) else: @@ -708,7 +711,7 @@ def _execute(self) -> None: processor_timeout = timedelta(seconds=processor_timeout_seconds) if not self._standalone_dag_processor: self.processor_agent = DagFileProcessorAgent( - dag_directory=self.subdir, + dag_directory=Path(self.subdir), max_runs=self.num_times_parse_dags, processor_timeout=processor_timeout, dag_ids=[], @@ -834,6 +837,12 @@ def _run_scheduler_loop(self) -> None: ) timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags) + if self._standalone_dag_processor: + timers.call_regular_interval( + conf.getfloat('scheduler', 'deactivate_stale_dags_interval', fallback=60.0), + self._cleanup_stale_dags, + ) + for loop_count in itertools.count(start=1): with Stats.timer() as timer: @@ -1260,6 +1269,7 @@ def _schedule_dag_run( dag_id=dag.dag_id, run_id=dag_run.run_id, is_failure_callback=True, + processor_subdir=dag_model.processor_subdir, msg='timed_out', ) @@ -1322,7 +1332,12 @@ def _send_sla_callbacks_to_processor(self, dag: DAG) -> None: self.log.debug("Skipping SLA check for %s because DAG is not scheduled", dag) return - request = SlaCallbackRequest(full_filepath=dag.fileloc, dag_id=dag.dag_id) + dag_model = DagModel.get_dagmodel(dag.dag_id) + request = SlaCallbackRequest( + full_filepath=dag.fileloc, + dag_id=dag.dag_id, + processor_subdir=dag_model.processor_subdir, + ) self.executor.send_callback(request) @provide_session @@ -1485,11 +1500,11 @@ def _find_zombies(self, session: Session) -> None: zombie_message_details = self._generate_zombie_message_details(ti) request = TaskCallbackRequest( full_filepath=file_loc, + processor_subdir=ti.dag_model.processor_subdir, simple_task_instance=SimpleTaskInstance.from_ti(ti), msg=str(zombie_message_details), ) - - self.log.error("Detected zombie job: %s", request.msg) + self.log.error("Detected zombie job: %s", request) self.executor.send_callback(request) Stats.incr('zombies_killed') @@ -1509,3 +1524,28 @@ def _generate_zombie_message_details(ti: TaskInstance): zombie_message_details["External Executor Id"] = ti.external_executor_id return zombie_message_details + + @provide_session + def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None: + """ + Find all dags that were not updated by Dag Processor recently and mark them as inactive. + + In case one of DagProcessors is stopped (in case there are multiple of them + for different dag folders), it's dags are never marked as inactive. + Also remove dags from SerializedDag table. + Executed on schedule only if [scheduler]standalone_dag_processor is True. 
+ """ + self.log.debug("Checking dags not parsed within last %s seconds.", self._dag_stale_not_seen_duration) + limit_lpt = timezone.utcnow() - timedelta(seconds=self._dag_stale_not_seen_duration) + stale_dags = ( + session.query(DagModel).filter(DagModel.is_active, DagModel.last_parsed_time < limit_lpt).all() + ) + if not stale_dags: + self.log.debug("Not stale dags found.") + return + + self.log.info("Found (%d) stales dags not parsed after %s.", len(stale_dags), limit_lpt) + for dag in stale_dags: + dag.is_active = False + SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session) + session.flush() diff --git a/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py new file mode 100644 index 0000000000000..ae16a3fffb59a --- /dev/null +++ b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add processor_subdir column to DagModel, SerializedDagModel and CallbackRequest tables. + +Revision ID: ecb43d2a1842 +Revises: 1486deb605b4 +Create Date: 2022-08-26 11:30:11.249580 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = 'ecb43d2a1842' +down_revision = '1486deb605b4' +branch_labels = None +depends_on = None +airflow_version = '2.4.0' + + +def upgrade(): + """Apply add processor_subdir to DagModel and SerializedDagModel""" + conn = op.get_bind() + + with op.batch_alter_table('dag') as batch_op: + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column('processor_subdir', sa.Text(length=2000), nullable=True)) + else: + batch_op.add_column(sa.Column('processor_subdir', sa.String(length=2000), nullable=True)) + + with op.batch_alter_table('serialized_dag') as batch_op: + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column('processor_subdir', sa.Text(length=2000), nullable=True)) + else: + batch_op.add_column(sa.Column('processor_subdir', sa.String(length=2000), nullable=True)) + + with op.batch_alter_table('callback_request') as batch_op: + batch_op.drop_column('dag_directory') + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column('processor_subdir', sa.Text(length=2000), nullable=True)) + else: + batch_op.add_column(sa.Column('processor_subdir', sa.String(length=2000), nullable=True)) + + +def downgrade(): + """Unapply Add processor_subdir to DagModel and SerializedDagModel""" + conn = op.get_bind() + with op.batch_alter_table('dag', schema=None) as batch_op: + batch_op.drop_column('processor_subdir') + + with op.batch_alter_table('serialized_dag', schema=None) as batch_op: + batch_op.drop_column('processor_subdir') + + with op.batch_alter_table('callback_request') as batch_op: + batch_op.drop_column('processor_subdir') + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column('dag_directory', sa.Text(length=1000), nullable=True)) + else: + batch_op.add_column(sa.Column('dag_directory', sa.String(length=1000), nullable=True)) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index b3b10b6d66b4c..c7336b8aaa549 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2532,18 +2532,27 @@ def create_dagrun( @classmethod @provide_session - def bulk_sync_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): + def bulk_sync_to_db( + cls, + dags: Collection["DAG"], + session=NEW_SESSION, + ): """This method is deprecated in favor of bulk_write_to_db""" warnings.warn( "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db", RemovedInAirflow3Warning, stacklevel=2, ) - return cls.bulk_write_to_db(dags, session) + return cls.bulk_write_to_db(dags=dags, session=session) @classmethod @provide_session - def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): + def bulk_write_to_db( + cls, + dags: Collection["DAG"], + processor_subdir: Optional[str] = None, + session=NEW_SESSION, + ): """ Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including calculated fields. 
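Between the migration above and the ORM changes that follow, the new column is populated along a single path: the standalone processor hands its directory to `DagFileProcessor`, which forwards it through `DagBag.sync_to_db` into `DAG.bulk_write_to_db`, where it is stamped onto each `DagModel` row. A hedged sketch of that call path, assuming an Airflow 2.4.0 install with an initialized metadata database; the DAGs path is made up:

from airflow.models.dagbag import DagBag

dag_directory = "/opt/airflow/dags/team_a"  # illustrative subdirectory
dagbag = DagBag(dag_folder=dag_directory, include_examples=False)

# DagFileProcessor.process_file() forwards its dag_directory here, and
# DAG.bulk_write_to_db() stamps it onto DagModel.processor_subdir per DAG.
dagbag.sync_to_db(processor_subdir=dag_directory)
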
@@ -2624,6 +2633,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag is not None for t in dag.tasks) orm_dag.schedule_interval = dag.schedule_interval orm_dag.timetable_description = dag.timetable.description + orm_dag.processor_subdir = processor_subdir run: Optional[DagRun] = most_recent_runs.get(dag.dag_id) if run is None: @@ -2729,10 +2739,10 @@ class InletRef(NamedTuple): session.flush() for dag in dags: - cls.bulk_write_to_db(dag.subdags, session=session) + cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session) @provide_session - def sync_to_db(self, session=NEW_SESSION): + def sync_to_db(self, processor_subdir: Optional[str] = None, session=NEW_SESSION): """ Save attributes about this DAG to the DB. Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a @@ -2740,7 +2750,7 @@ def sync_to_db(self, session=NEW_SESSION): :return: None """ - self.bulk_write_to_db([self], session) + self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session) def get_default_view(self): """This is only there for backward compatible jinja2 templates""" @@ -2977,6 +2987,8 @@ class DagModel(Base): # packaged DAG, it will point to the subpath of the DAG within the # associated zip. fileloc = Column(String(2000)) + # The base directory used by Dag Processor that parsed this dag. + processor_subdir = Column(String(2000), nullable=True) # String representing the owners owners = Column(String(2000)) # Description of the dag diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 3183f8a8f1dae..3f8d6e57ca38e 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -575,7 +575,7 @@ def dagbag_report(self): return report @provide_session - def sync_to_db(self, session: Session = None): + def sync_to_db(self, processor_subdir: Optional[str] = None, session: Session = None): """Save attributes about list of DAG to the DB.""" # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag from airflow.models.dag import DAG @@ -622,7 +622,9 @@ def _serialize_dag_capturing_errors(dag, session): for dag in self.dags.values(): serialize_errors.extend(_serialize_dag_capturing_errors(dag, session)) - DAG.bulk_write_to_db(self.dags.values(), session=session) + DAG.bulk_write_to_db( + self.dags.values(), processor_subdir=processor_subdir, session=session + ) except OperationalError: session.rollback() raise diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 41c9729282b99..65f92fd8c8838 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -572,11 +572,15 @@ def recalculate(self) -> "_UnfinishedStates": if execute_callbacks: dag.handle_callback(self, success=False, reason='task_failure', session=session) elif dag.has_on_failure_callback: + from airflow.models.dag import DagModel + + dag_model = DagModel.get_dagmodel(dag.dag_id, session) callback = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=True, + processor_subdir=dag_model.processor_subdir, msg='task_failure', ) @@ -587,11 +591,15 @@ def recalculate(self) -> "_UnfinishedStates": if execute_callbacks: dag.handle_callback(self, success=True, reason='success', session=session) elif dag.has_on_success_callback: + from airflow.models.dag import DagModel + + dag_model = DagModel.get_dagmodel(dag.dag_id, session) callback = DagCallbackRequest( 
full_filepath=dag.fileloc, dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=False, + processor_subdir=dag_model.processor_subdir, msg='success', ) @@ -602,11 +610,15 @@ def recalculate(self) -> "_UnfinishedStates": if execute_callbacks: dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session) elif dag.has_on_failure_callback: + from airflow.models.dag import DagModel + + dag_model = DagModel.get_dagmodel(dag.dag_id, session) callback = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=True, + processor_subdir=dag_model.processor_subdir, msg='all_tasks_deadlocked', ) diff --git a/airflow/models/db_callback_request.py b/airflow/models/db_callback_request.py index 4fdd36a71be4b..1f6ee4cd851fe 100644 --- a/airflow/models/db_callback_request.py +++ b/airflow/models/db_callback_request.py @@ -36,11 +36,12 @@ class DbCallbackRequest(Base): priority_weight = Column(Integer(), nullable=False) callback_data = Column(ExtendedJSON, nullable=False) callback_type = Column(String(20), nullable=False) - dag_directory = Column(String(1000), nullable=True) + processor_subdir = Column(String(2000), nullable=True) def __init__(self, priority_weight: int, callback: CallbackRequest): self.created_at = timezone.utcnow() self.priority_weight = priority_weight + self.processor_subdir = callback.processor_subdir self.callback_data = callback.to_json() self.callback_type = callback.__class__.__name__ diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 9114af987c4b5..baf3691f50aa6 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional import sqlalchemy_jsonfield -from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_ +from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_ from sqlalchemy.orm import Session, backref, foreign, relationship from sqlalchemy.sql.expression import func, literal @@ -72,6 +72,7 @@ class SerializedDagModel(Base): _data_compressed = Column('data_compressed', LargeBinary, nullable=True) last_updated = Column(UtcDateTime, nullable=False) dag_hash = Column(String(32), nullable=False) + processor_subdir = Column(String(2000), nullable=True) __table_args__ = (Index('idx_fileloc_hash', fileloc_hash, unique=False),) @@ -92,11 +93,12 @@ class SerializedDagModel(Base): load_op_links = True - def __init__(self, dag: DAG): + def __init__(self, dag: DAG, processor_subdir: Optional[str] = None): self.dag_id = dag.dag_id self.fileloc = dag.fileloc self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc) self.last_updated = timezone.utcnow() + self.processor_subdir = processor_subdir dag_data = SerializedDAG.to_dict(dag) dag_data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8") @@ -119,7 +121,13 @@ def __repr__(self): @classmethod @provide_session - def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: Session = None) -> bool: + def write_dag( + cls, + dag: DAG, + min_update_interval: Optional[int] = None, + processor_subdir: Optional[str] = None, + session: Session = None, + ) -> bool: """Serializes a DAG and writes it into database. If the record already exists, it checks if the Serialized DAG changed or not. If it is changed, it updates the record, ignores otherwise. 
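One consequence of the `write_dag` change in the next hunk: a serialized record is now rewritten when either the DAG hash or the `processor_subdir` differs, so moving a DAG file between standalone processors refreshes the row even if the DAG itself is unchanged. Paraphrased as a small helper; the function and argument names are illustrative, not part of the patch:

def needs_rewrite(existing_row, new_model) -> bool:
    # existing_row: the (dag_hash, processor_subdir) row fetched from
    # serialized_dag, or None when the DAG has not been serialized yet.
    return (
        existing_row is None
        or existing_row.dag_hash != new_model.dag_hash
        or existing_row.processor_subdir != new_model.processor_subdir
    )
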
@@ -151,10 +159,16 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: return False log.debug("Checking if DAG (%s) changed", dag.dag_id) - new_serialized_dag = cls(dag) - serialized_dag_hash_from_db = session.query(cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar() + new_serialized_dag = cls(dag, processor_subdir) + serialized_dag_db = ( + session.query(cls.dag_hash, cls.processor_subdir).filter(cls.dag_id == dag.dag_id).first() + ) - if serialized_dag_hash_from_db == new_serialized_dag.dag_hash: + if ( + serialized_dag_db is not None + and serialized_dag_db.dag_hash == new_serialized_dag.dag_hash + and serialized_dag_db.processor_subdir == new_serialized_dag.processor_subdir + ): log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag.dag_id) return False @@ -222,7 +236,9 @@ def remove_dag(cls, dag_id: str, session: Session = None): @classmethod @provide_session - def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): + def remove_deleted_dags( + cls, alive_dag_filelocs: List[str], processor_subdir: Optional[str] = None, session=None + ): """Deletes DAGs not included in alive_dag_filelocs. :param alive_dag_filelocs: file paths of alive DAGs @@ -236,7 +252,14 @@ def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): session.execute( cls.__table__.delete().where( - and_(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)) + and_( + cls.fileloc_hash.notin_(alive_fileloc_hashes), + cls.fileloc.notin_(alive_dag_filelocs), + or_( + cls.processor_subdir is None, + cls.processor_subdir == processor_subdir, + ), + ) ) ) @@ -281,7 +304,7 @@ def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagMod @staticmethod @provide_session - def bulk_sync_to_db(dags: List[DAG], session: Session = None): + def bulk_sync_to_db(dags: List[DAG], processor_subdir: Optional[str] = None, session: Session = None): """ Saves DAGs as Serialized DAG objects in the database. Each DAG is saved in a separate database query. @@ -293,7 +316,10 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None): for dag in dags: if not dag.is_subdag: SerializedDagModel.write_dag( - dag, min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, session=session + dag=dag, + min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, + processor_subdir=processor_subdir, + session=session, ) @classmethod diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 28b5928738d8b..cd4a608fe345b 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -27,7 +27,10 @@ Here's the list of all the Database Migrations that are executed via when you ru +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=================================+===================+===================+==============================================================+ -| ``1486deb605b4`` (head) | ``f4ff391becb5`` | ``2.4.0`` | add dag_owner_attributes table | +| ``ecb43d2a1842`` (head) | ``1486deb605b4`` | ``2.4.0`` | Add processor_subdir column to DagModel, SerializedDagModel | +| | | | and CallbackRequest tables. 
| ++---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ +| ``1486deb605b4`` | ``f4ff391becb5`` | ``2.4.0`` | add dag_owner_attributes table | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ | ``f4ff391becb5`` | ``0038cd0c28b4`` | ``2.4.0`` | Remove smart sensors | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py index 3764f19c4c4f3..8819571f57714 100644 --- a/tests/callbacks/test_callback_requests.py +++ b/tests/callbacks/test_callback_requests.py @@ -46,6 +46,7 @@ class TestCallbackRequest: TaskCallbackRequest( full_filepath="filepath", simple_task_instance=SimpleTaskInstance.from_ti(ti=TI), + processor_subdir='/test_dir', is_failure_callback=True, ), TaskCallbackRequest, @@ -55,11 +56,19 @@ class TestCallbackRequest: full_filepath="filepath", dag_id="fake_dag", run_id="fake_run", + processor_subdir='/test_dir', is_failure_callback=False, ), DagCallbackRequest, ), - (SlaCallbackRequest(full_filepath="filepath", dag_id="fake_dag"), SlaCallbackRequest), + ( + SlaCallbackRequest( + full_filepath="filepath", + dag_id="fake_dag", + processor_subdir='/test_dir', + ), + SlaCallbackRequest, + ), ] ) def test_from_json(self, input, request_class): @@ -76,6 +85,7 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create input = TaskCallbackRequest( full_filepath="filepath", simple_task_instance=SimpleTaskInstance.from_ti(ti), + processor_subdir='/test_dir', is_failure_callback=True, ) json_str = input.to_json() diff --git a/tests/conftest.py b/tests/conftest.py index 2d0b90d821588..b974bf25a097c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -518,11 +518,13 @@ def __exit__(self, type, value, traceback): return dag.clear(session=self.session) - dag.sync_to_db(self.session) + dag.sync_to_db(processor_subdir=self.processor_subdir, session=self.session) self.dag_model = self.session.query(DagModel).get(dag.dag_id) if self.want_serialized: - self.serialized_model = SerializedDagModel(dag) + self.serialized_model = SerializedDagModel( + dag, processor_subdir=self.dag_model.processor_subdir + ) self.session.merge(self.serialized_model) serialized_dag = self._serialized_dag() self.dagbag.bag_dag(serialized_dag, root_dag=serialized_dag) @@ -578,7 +580,13 @@ def create_dagrun_after(self, dagrun, **kwargs): ) def __call__( - self, dag_id='test_dag', serialized=want_serialized, fileloc=None, session=None, **kwargs + self, + dag_id='test_dag', + serialized=want_serialized, + fileloc=None, + processor_subdir=None, + session=None, + **kwargs, ): from airflow import settings from airflow.models import DAG @@ -606,6 +614,7 @@ def __call__( self.dag = DAG(dag_id, **self.kwargs) self.dag.fileloc = fileloc or request.module.__file__ self.want_serialized = serialized + self.processor_subdir = processor_subdir return self diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index f88c27dafd3fe..70208312e9e91 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -66,8 +66,8 @@ class FakeDagFileProcessorRunner(DagFileProcessorProcess): # This fake processor will return the zombies it received in constructor # as its processing result w/o actually 
parsing anything. - def __init__(self, file_path, pickle_dags, dag_ids, callbacks): - super().__init__(file_path, pickle_dags, dag_ids, callbacks) + def __init__(self, file_path, pickle_dags, dag_ids, dag_directory, callbacks): + super().__init__(file_path, pickle_dags, dag_ids, dag_directory, callbacks) # We need a "real" selectable handle for waitable_handle to work readable, writable = multiprocessing.Pipe(duplex=False) writable.send('abc') @@ -95,11 +95,12 @@ def result(self): return self._result @staticmethod - def _create_process(file_path, callback_requests, dag_ids, pickle_dags): + def _create_process(file_path, callback_requests, dag_ids, dag_directory, pickle_dags): return FakeDagFileProcessorRunner( file_path, pickle_dags, dag_ids, + dag_directory, callback_requests, ) @@ -504,7 +505,6 @@ def test_deactivate_stale_dags(self): ) assert serialized_dag_count == 1 - manager._file_stats[test_dag_path] = stat manager._deactivate_stale_dags() active_dag_count = ( @@ -521,6 +521,62 @@ def test_deactivate_stale_dags(self): ) assert serialized_dag_count == 0 + @conf_vars( + { + ('core', 'load_examples'): 'False', + ('scheduler', 'standalone_dag_processor'): 'True', + } + ) + def test_deactivate_stale_dags_standalone_mode(self): + """ + Ensure only dags from current dag_directory are updated + """ + dag_directory = 'directory' + manager = DagFileProcessorManager( + dag_directory=dag_directory, + max_runs=1, + processor_timeout=timedelta(minutes=10), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ) + + test_dag_path = str(TEST_DAG_FOLDER / 'test_example_bash_operator.py') + dagbag = DagBag(test_dag_path, read_dags_from_db=False) + other_test_dag_path = str(TEST_DAG_FOLDER / 'test_scheduler_dags.py') + other_dagbag = DagBag(other_test_dag_path, read_dags_from_db=False) + + with create_session() as session: + # Add stale DAG to the DB + dag = dagbag.get_dag('test_example_bash_operator') + dag.last_parsed_time = timezone.utcnow() + dag.sync_to_db(processor_subdir=dag_directory) + + # Add stale DAG to the DB + other_dag = other_dagbag.get_dag('test_start_date_scheduling') + other_dag.last_parsed_time = timezone.utcnow() + other_dag.sync_to_db(processor_subdir='other') + + # Add DAG to the file_parsing_stats + stat = DagFileStat( + num_dags=1, + import_errors=0, + last_finish_time=timezone.utcnow() + timedelta(hours=1), + last_duration=1, + run_count=1, + ) + manager._file_paths = [test_dag_path] + manager._file_stats[test_dag_path] = stat + + active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() + assert active_dag_count == 2 + + manager._deactivate_stale_dags() + + active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() + assert active_dag_count == 1 + @mock.patch( "airflow.dag_processing.processor.DagFileProcessorProcess.waitable_handle", new_callable=PropertyMock ) @@ -539,7 +595,13 @@ def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid, mock_waitable async_mode=True, ) - processor = DagFileProcessorProcess('abc.txt', False, [], []) + processor = DagFileProcessorProcess( + file_path='abc.txt', + pickle_dags=False, + dag_ids=[], + dag_directory=TEST_DAG_FOLDER, + callback_requests=[], + ) processor._start_time = timezone.make_aware(datetime.min) manager._processors = {'abc.txt': processor} manager.waitables[3] = processor @@ -554,7 +616,7 @@ def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid, mock_waitable def 
test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_pid): mock_pid.return_value = 1234 manager = DagFileProcessorManager( - dag_directory='directory', + dag_directory=TEST_DAG_FOLDER, max_runs=1, processor_timeout=timedelta(seconds=5), signal_conn=MagicMock(), @@ -563,7 +625,13 @@ def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_p async_mode=True, ) - processor = DagFileProcessorProcess('abc.txt', False, [], []) + processor = DagFileProcessorProcess( + file_path='abc.txt', + pickle_dags=False, + dag_ids=[], + dag_directory=str(TEST_DAG_FOLDER), + callback_requests=[], + ) processor._start_time = timezone.make_aware(datetime.max) manager._processors = {'abc.txt': processor} manager._kill_timed_out_processors() @@ -757,17 +825,20 @@ def test_fetch_callbacks_from_database(self, tmpdir): dag_id="test_start_date_scheduling", full_filepath=str(dag_filepath), is_failure_callback=True, + processor_subdir=str(tmpdir), run_id='123', ) callback2 = DagCallbackRequest( dag_id="test_start_date_scheduling", full_filepath=str(dag_filepath), is_failure_callback=True, + processor_subdir=str(tmpdir), run_id='456', ) callback3 = SlaCallbackRequest( dag_id="test_start_date_scheduling", full_filepath=str(dag_filepath), + processor_subdir=str(tmpdir), ) with create_session() as session: @@ -777,7 +848,7 @@ def test_fetch_callbacks_from_database(self, tmpdir): child_pipe, parent_pipe = multiprocessing.Pipe() manager = DagFileProcessorManager( - dag_directory=tmpdir, + dag_directory=str(tmpdir), max_runs=1, processor_timeout=timedelta(days=365), signal_conn=child_pipe, @@ -790,6 +861,50 @@ def test_fetch_callbacks_from_database(self, tmpdir): self.run_processor_manager_one_loop(manager, parent_pipe) assert session.query(DbCallbackRequest).count() == 0 + @conf_vars( + { + ('core', 'load_examples'): 'False', + ('scheduler', 'standalone_dag_processor'): 'True', + } + ) + def test_fetch_callbacks_for_current_dag_directory_only(self, tmpdir): + """Test DagFileProcessorManager._fetch_callbacks method""" + dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" + + callback1 = DagCallbackRequest( + dag_id="test_start_date_scheduling", + full_filepath=str(dag_filepath), + is_failure_callback=True, + processor_subdir=str(tmpdir), + run_id='123', + ) + callback2 = DagCallbackRequest( + dag_id="test_start_date_scheduling", + full_filepath=str(dag_filepath), + is_failure_callback=True, + processor_subdir="/some/other/dir/", + run_id='456', + ) + + with create_session() as session: + session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) + session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) + + child_pipe, parent_pipe = multiprocessing.Pipe() + manager = DagFileProcessorManager( + dag_directory=tmpdir, + max_runs=1, + processor_timeout=timedelta(days=365), + signal_conn=child_pipe, + dag_ids=[], + pickle_dags=False, + async_mode=False, + ) + + with create_session() as session: + self.run_processor_manager_one_loop(manager, parent_pipe) + assert session.query(DbCallbackRequest).count() == 1 + @conf_vars( { ('scheduler', 'standalone_dag_processor'): 'True', @@ -808,12 +923,13 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmpdir): full_filepath=str(dag_filepath), is_failure_callback=True, run_id=str(i), + processor_subdir=str(tmpdir), ) session.add(DbCallbackRequest(callback=callback, priority_weight=i)) child_pipe, parent_pipe = multiprocessing.Pipe() manager = DagFileProcessorManager( - dag_directory=tmpdir, + 
dag_directory=str(tmpdir), max_runs=1, processor_timeout=timedelta(days=365), signal_conn=child_pipe, @@ -844,6 +960,7 @@ def test_fetch_callbacks_from_database_not_standalone(self, tmpdir): dag_id="test_start_date_scheduling", full_filepath=str(dag_filepath), is_failure_callback=True, + processor_subdir=str(tmpdir), run_id='123', ) session.add(DbCallbackRequest(callback=callback, priority_weight=10)) @@ -884,6 +1001,7 @@ def test_callback_queue(self, tmpdir): dag_id="dag1", run_id="run1", is_failure_callback=False, + processor_subdir=tmpdir, msg=None, ) dag1_req2 = DagCallbackRequest( @@ -891,16 +1009,26 @@ def test_callback_queue(self, tmpdir): dag_id="dag1", run_id="run1", is_failure_callback=False, + processor_subdir=tmpdir, msg=None, ) - dag1_sla1 = SlaCallbackRequest(full_filepath="/green_eggs/ham/file1.py", dag_id="dag1") - dag1_sla2 = SlaCallbackRequest(full_filepath="/green_eggs/ham/file1.py", dag_id="dag1") + dag1_sla1 = SlaCallbackRequest( + full_filepath="/green_eggs/ham/file1.py", + dag_id="dag1", + processor_subdir=tmpdir, + ) + dag1_sla2 = SlaCallbackRequest( + full_filepath="/green_eggs/ham/file1.py", + dag_id="dag1", + processor_subdir=tmpdir, + ) dag2_req1 = DagCallbackRequest( full_filepath="/green_eggs/ham/file2.py", dag_id="dag2", run_id="run1", is_failure_callback=False, + processor_subdir=tmpdir, msg=None, ) @@ -946,10 +1074,6 @@ def tearDown(self): for mod in remove_list: del sys.modules[mod] - @staticmethod - def _processor_factory(file_path, zombies, dag_ids, pickle_dags): - return DagFileProcessorProcess(file_path, pickle_dags, dag_ids, zombies) - def test_reload_module(self): """ Configure the context to have logging.logging_config_class set to a fake logging diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index d007851092f29..b0e09d0417df4 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -97,8 +97,10 @@ def teardown_method(self) -> None: self.scheduler_job = None self.clean_db() - def _process_file(self, file_path, session): - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + def _process_file(self, file_path, dag_directory, session): + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=str(dag_directory), log=mock.MagicMock() + ) dag_file_processor.process_file(file_path, [], False, session) @@ -124,7 +126,9 @@ def test_dag_file_processor_sla_miss_callback(self, create_dummy_dag): session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) dag_file_processor.manage_slas(dag=dag, session=session) assert sla_callback.called @@ -153,7 +157,9 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self, create_dummy_dag session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) dag_file_processor.manage_slas(dag=dag, session=session) sla_callback.assert_not_called() @@ -192,7 +198,9 @@ def test_dag_file_processor_sla_miss_callback_sent_notification(self, create_dum ) # Now call manage_slas and see if the sla_miss callback gets called - dag_file_processor = DagFileProcessor(dag_ids=[], 
log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) dag_file_processor.manage_slas(dag=dag, session=session) sla_callback.assert_not_called() @@ -220,7 +228,9 @@ def test_dag_file_processor_sla_miss_doesnot_raise_integrity_error(self, mock_st session.merge(ti) session.flush() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) dag_file_processor.manage_slas(dag=dag, session=session) sla_miss_count = ( session.query(SlaMiss) @@ -264,7 +274,7 @@ def test_dag_file_processor_sla_miss_callback_exception(self, mock_stats_incr, c # Now call manage_slas and see if the sla_miss callback gets called mock_log = mock.MagicMock() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log) dag_file_processor.manage_slas(dag=dag, session=session) assert sla_callback.called mock_log.exception.assert_called_once_with( @@ -294,7 +304,9 @@ def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks( session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) dag_file_processor.manage_slas(dag=dag, session=session) @@ -333,7 +345,7 @@ def test_dag_file_processor_sla_miss_email_exception( session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) mock_log = mock.MagicMock() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log) dag_file_processor.manage_slas(dag=dag, session=session) mock_log.exception.assert_called_once_with( @@ -364,13 +376,15 @@ def test_dag_file_processor_sla_miss_deleted_task(self, create_dummy_dag): ) mock_log = mock.MagicMock() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log) dag_file_processor.manage_slas(dag=dag, session=session) @patch.object(TaskInstance, 'handle_failure') def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) with create_session() as session: session.query(TaskInstance).delete() dag = dagbag.get_dag('example_branch_operator') @@ -401,7 +415,9 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): @patch.object(TaskInstance, 'handle_failure') def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag): dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) with create_session() as session: session.query(TaskInstance).delete() dag = dagbag.get_dag('example_branch_operator') @@ -431,7 +447,9 @@ def 
test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, def test_failure_callbacks_should_not_drop_hostname(self): dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) dag_file_processor.UNIT_TEST_MODE = False with create_session() as session: @@ -462,7 +480,9 @@ def test_process_file_should_failure_callback(self, monkeypatch, tmp_path, get_t callback_file = tmp_path.joinpath("callback.txt") callback_file.touch() monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor = DagFileProcessor( + dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock() + ) dag = get_test_dag('test_on_failure_callback') task = dag.get_task(task_id='test_on_failure_callback_task') @@ -496,7 +516,7 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self, tmpd unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) with create_session() as session: - self._process_file(unparseable_filename, session) + self._process_file(unparseable_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 1 @@ -513,7 +533,7 @@ def test_add_unparseable_zip_file_creates_import_error(self, tmpdir): zip_file.writestr(TEMP_DAG_FILENAME, UNPARSEABLE_DAG_FILE_CONTENTS) with create_session() as session: - self._process_file(zip_filename, session) + self._process_file(zip_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 1 @@ -530,7 +550,7 @@ def test_dag_model_has_import_error_is_true_when_import_error_exists(self, tmpdi for line in main_dag: next_dag.write(line) # first we parse the dag - self._process_file(temp_dagfile, session) + self._process_file(temp_dagfile, dag_directory=tmpdir, session=session) # assert DagModel.has_import_errors is false dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first() assert not dm.has_import_errors @@ -538,7 +558,7 @@ def test_dag_model_has_import_error_is_true_when_import_error_exists(self, tmpdi with open(temp_dagfile, 'a') as file: file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) - self._process_file(temp_dagfile, session) + self._process_file(temp_dagfile, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 1 @@ -555,7 +575,7 @@ def test_no_import_errors_with_parseable_dag(self, tmpdir): parseable_file.writelines(PARSEABLE_DAG_FILE_CONTENTS) with create_session() as session: - self._process_file(parseable_filename, session) + self._process_file(parseable_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 0 @@ -568,7 +588,7 @@ def test_no_import_errors_with_parseable_dag_in_zip(self, tmpdir): zip_file.writestr(TEMP_DAG_FILENAME, PARSEABLE_DAG_FILE_CONTENTS) with create_session() as session: - self._process_file(zip_filename, session) + self._process_file(zip_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 0 @@ -583,14 +603,14 @@ def test_new_import_error_replaces_old(self, tmpdir): with open(unparseable_filename, 'w') as 
unparseable_file: unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) session = settings.Session() - self._process_file(unparseable_filename, session) + self._process_file(unparseable_filename, dag_directory=tmpdir, session=session) # Generate replacement import error (the error will be on the second line now) with open(unparseable_filename, 'w') as unparseable_file: unparseable_file.writelines( PARSEABLE_DAG_FILE_CONTENTS + os.linesep + UNPARSEABLE_DAG_FILE_CONTENTS ) - self._process_file(unparseable_filename, session) + self._process_file(unparseable_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() @@ -612,7 +632,7 @@ def test_import_error_record_is_updated_not_deleted_and_recreated(self, tmpdir): with open(filename_to_parse, 'w') as file_to_parse: file_to_parse.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) session = settings.Session() - self._process_file(filename_to_parse, session) + self._process_file(filename_to_parse, dag_directory=tmpdir, session=session) import_error_1 = ( session.query(errors.ImportError).filter(errors.ImportError.filename == filename_to_parse).one() @@ -620,7 +640,7 @@ def test_import_error_record_is_updated_not_deleted_and_recreated(self, tmpdir): # process the file multiple times for _ in range(10): - self._process_file(filename_to_parse, session) + self._process_file(filename_to_parse, dag_directory=tmpdir, session=session) import_error_2 = ( session.query(errors.ImportError).filter(errors.ImportError.filename == filename_to_parse).one() @@ -636,12 +656,12 @@ def test_remove_error_clears_import_error(self, tmpdir): with open(filename_to_parse, 'w') as file_to_parse: file_to_parse.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) session = settings.Session() - self._process_file(filename_to_parse, session) + self._process_file(filename_to_parse, dag_directory=tmpdir, session=session) # Remove the import error from the file with open(filename_to_parse, 'w') as file_to_parse: file_to_parse.writelines(PARSEABLE_DAG_FILE_CONTENTS) - self._process_file(filename_to_parse, session) + self._process_file(filename_to_parse, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() @@ -656,7 +676,7 @@ def test_remove_error_clears_import_error_zip(self, tmpdir): zip_filename = os.path.join(tmpdir, "test_zip.zip") with ZipFile(zip_filename, "w") as zip_file: zip_file.writestr(TEMP_DAG_FILENAME, UNPARSEABLE_DAG_FILE_CONTENTS) - self._process_file(zip_filename, session) + self._process_file(zip_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 1 @@ -664,7 +684,7 @@ def test_remove_error_clears_import_error_zip(self, tmpdir): # Remove the import error from the file with ZipFile(zip_filename, "w") as zip_file: zip_file.writestr(TEMP_DAG_FILENAME, 'import os # airflow DAG') - self._process_file(zip_filename, session) + self._process_file(zip_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 0 @@ -677,7 +697,7 @@ def test_import_error_tracebacks(self, tmpdir): unparseable_file.writelines(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS) with create_session() as session: - self._process_file(unparseable_filename, session) + self._process_file(unparseable_filename, dag_directory=tmpdir, session=session) import_errors = session.query(errors.ImportError).all() assert len(import_errors) == 1 @@ -703,7 +723,7 @@ def 
             unparseable_file.writelines(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
 
         with create_session() as session:
-            self._process_file(unparseable_filename, session)
+            self._process_file(unparseable_filename, dag_directory=tmpdir, session=session)
             import_errors = session.query(errors.ImportError).all()
 
             assert len(import_errors) == 1
@@ -726,7 +746,7 @@ def test_import_error_tracebacks_zip(self, tmpdir):
             invalid_zip_file.writestr(TEMP_DAG_FILENAME, INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
 
         with create_session() as session:
-            self._process_file(invalid_zip_filename, session)
+            self._process_file(invalid_zip_filename, dag_directory=tmpdir, session=session)
             import_errors = session.query(errors.ImportError).all()
             assert len(import_errors) == 1
@@ -753,7 +773,7 @@ def test_import_error_tracebacks_zip_depth(self, tmpdir):
             invalid_zip_file.writestr(TEMP_DAG_FILENAME, INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
 
         with create_session() as session:
-            self._process_file(invalid_zip_filename, session)
+            self._process_file(invalid_zip_filename, dag_directory=tmpdir, session=session)
             import_errors = session.query(errors.ImportError).all()
             assert len(import_errors) == 1
@@ -779,7 +799,7 @@ def per_test(self):
 
     def test_error_when_waiting_in_async_mode(self, tmp_path):
         self.processor_agent = DagFileProcessorAgent(
-            dag_directory=str(tmp_path),
+            dag_directory=tmp_path,
             max_runs=1,
             processor_timeout=datetime.timedelta(1),
             dag_ids=[],
@@ -792,7 +812,7 @@ def test_error_when_waiting_in_async_mode(self, tmp_path):
 
     def test_default_multiprocessing_behaviour(self, tmp_path):
         self.processor_agent = DagFileProcessorAgent(
-            dag_directory=str(tmp_path),
+            dag_directory=tmp_path,
             max_runs=1,
             processor_timeout=datetime.timedelta(1),
             dag_ids=[],
@@ -806,7 +826,7 @@ def test_default_multiprocessing_behaviour(self, tmp_path):
 
     @conf_vars({("core", "mp_start_method"): "spawn"})
     def test_spawn_multiprocessing_behaviour(self, tmp_path):
         self.processor_agent = DagFileProcessorAgent(
-            dag_directory=str(tmp_path),
+            dag_directory=tmp_path,
             max_runs=1,
             processor_timeout=datetime.timedelta(1),
             dag_ids=[],
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index b299b852c2724..c46f846b3e3c8 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -136,7 +136,6 @@ def set_instance_attrs(self, dagbag) -> Generator:
         # Speed up some tests by not running the tasks, just look at what we
         # enqueue!
         self.null_exec: Optional[MockExecutor] = MockExecutor()
-        # Since we don't want to store the code for the DAG defined in this file
         with patch('airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags'), patch(
             'airflow.models.dag.DagCode.bulk_sync_to_db'
         ):
@@ -343,6 +342,7 @@ def test_process_executor_events_with_callback(self, mock_stats_incr, mock_task_
         mock_task_callback.assert_called_once_with(
             full_filepath=dag.fileloc,
             simple_task_instance=mock.ANY,
+            processor_subdir=None,
             msg='Executor reports task instance '
             ' '
             'finished (failed) although the task says its queued. (Info: None) '
@@ -1610,6 +1610,7 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker):
             dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout',
             start_date=DEFAULT_DATE,
             max_active_runs=1,
+            processor_subdir=TEST_DAG_FOLDER,
             dagrun_timeout=datetime.timedelta(seconds=60),
         ) as dag:
             EmptyOperator(task_id='dummy')
@@ -1657,6 +1658,7 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker):
             dag_id=dr.dag_id,
             is_failure_callback=True,
             run_id=dr.run_id,
+            processor_subdir=TEST_DAG_FOLDER,
             msg="timed_out",
         )
@@ -1674,6 +1676,7 @@ def test_dagrun_timeout_fails_run(self, dag_maker):
         with dag_maker(
             dag_id='test_scheduler_fail_dagrun_timeout',
             dagrun_timeout=datetime.timedelta(seconds=60),
+            processor_subdir=TEST_DAG_FOLDER,
             session=session,
         ):
             EmptyOperator(task_id='dummy')
@@ -1697,6 +1700,7 @@ def test_dagrun_timeout_fails_run(self, dag_maker):
             dag_id=dr.dag_id,
             is_failure_callback=True,
             run_id=dr.run_id,
+            processor_subdir=TEST_DAG_FOLDER,
             msg="timed_out",
         )
@@ -1751,6 +1755,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak
             dag_id='test_dagrun_callbacks_are_called',
             on_success_callback=lambda x: print("success"),
             on_failure_callback=lambda x: print("failed"),
+            processor_subdir=TEST_DAG_FOLDER,
         ) as dag:
             EmptyOperator(task_id='dummy')
@@ -1773,6 +1778,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak
             dag_id=dr.dag_id,
             is_failure_callback=bool(state == State.FAILED),
             run_id=dr.run_id,
+            processor_subdir=TEST_DAG_FOLDER,
             msg=expected_callback_msg,
         )
@@ -1786,6 +1792,7 @@ def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, sessio
             dag_id='test_dagrun_timeout_callbacks_are_stored_in_database',
             on_failure_callback=lambda x: print("failed"),
             dagrun_timeout=timedelta(hours=1),
+            processor_subdir=TEST_DAG_FOLDER,
         ) as dag:
             EmptyOperator(task_id='empty')
@@ -1812,6 +1819,7 @@ def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, sessio
             dag_id=dr.dag_id,
             is_failure_callback=True,
             run_id=dr.run_id,
+            processor_subdir=TEST_DAG_FOLDER,
             msg='timed_out',
         )
@@ -2996,7 +3004,11 @@ def test_send_sla_callbacks_to_processor_sla_no_task_slas(self, dag_maker):
     def test_send_sla_callbacks_to_processor_sla_with_task_slas(self, schedule, dag_maker):
         """Test SLA Callbacks are sent to the DAG Processor when SLAs are defined on tasks"""
         dag_id = 'test_send_sla_callbacks_to_processor_sla_with_task_slas'
-        with dag_maker(dag_id=dag_id, schedule=schedule) as dag:
+        with dag_maker(
+            dag_id=dag_id,
+            schedule=schedule,
+            processor_subdir=TEST_DAG_FOLDER,
+        ) as dag:
             EmptyOperator(task_id='task1', sla=timedelta(seconds=60))
 
         with patch.object(settings, "CHECK_SLAS", True):
@@ -3005,7 +3017,11 @@ def test_send_sla_callbacks_to_processor_sla_with_task_slas(self, schedule, dag_
 
         self.scheduler_job._send_sla_callbacks_to_processor(dag)
 
-        expected_callback = SlaCallbackRequest(full_filepath=dag.fileloc, dag_id=dag.dag_id)
+        expected_callback = SlaCallbackRequest(
+            full_filepath=dag.fileloc,
+            dag_id=dag.dag_id,
+            processor_subdir=TEST_DAG_FOLDER,
+        )
         self.scheduler_job.executor.callback_sink.send.assert_called_once_with(expected_callback)
 
     @pytest.mark.parametrize(
@@ -3284,7 +3300,7 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak
         )
         assert dr is not None
         # Run DAG.bulk_write_to_db -- this is run when in DagFileProcessor.process_file
-        DAG.bulk_write_to_db([dag], session)
+        DAG.bulk_write_to_db([dag], session=session)
 
         # Test that 'dag_model.next_dagrun' has not been changed because of newly created external
         # triggered DagRun.
@@ -4134,7 +4150,7 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce
         session = settings.Session()
         session.query(LocalTaskJob).delete()
         dag = dagbag.get_dag('test_example_bash_operator')
-        dag.sync_to_db()
+        dag.sync_to_db(processor_subdir=TEST_DAG_FOLDER)
 
         dag_run = dag.create_dagrun(
             state=DagRunState.RUNNING,
@@ -4156,14 +4172,6 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce
         ti.job_id = local_job.id
         session.flush()
 
-        expected_failure_callback_requests = [
-            TaskCallbackRequest(
-                full_filepath=dag.fileloc,
-                simple_task_instance=SimpleTaskInstance.from_ti(ti),
-                msg="Message",
-            )
-        ]
-
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         self.scheduler_job.executor = MockExecutor()
         self.scheduler_job.processor_agent = mock.MagicMock()
@@ -4171,10 +4179,53 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce
         self.scheduler_job._find_zombies(session=session)
 
         self.scheduler_job.executor.callback_sink.send.assert_called_once()
+
+        expected_failure_callback_requests = [
+            TaskCallbackRequest(
+                full_filepath=dag.fileloc,
+                simple_task_instance=SimpleTaskInstance.from_ti(ti),
+                processor_subdir=TEST_DAG_FOLDER,
+                msg=str(self.scheduler_job._generate_zombie_message_details(ti)),
+            )
+        ]
         callback_requests = self.scheduler_job.executor.callback_sink.send.call_args[0]
+        assert len(callback_requests) == 1
         assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == {
             result.simple_task_instance.key for result in callback_requests
         }
+        expected_failure_callback_requests[0].simple_task_instance = None
+        callback_requests[0].simple_task_instance = None
+        assert expected_failure_callback_requests[0] == callback_requests[0]
+
+    def test_cleanup_stale_dags(self):
+        dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
+        with create_session() as session:
+            dag = dagbag.get_dag('test_example_bash_operator')
+            dag.sync_to_db()
+            dm = DagModel.get_current('test_example_bash_operator')
+            # Make it "stale".
+            dm.last_parsed_time = timezone.utcnow() - timedelta(minutes=11)
+            session.merge(dm)
+
+            # This one should remain active.
+            dag = dagbag.get_dag('test_start_date_scheduling')
+            dag.sync_to_db()
+
+            session.flush()
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.executor = MockExecutor()
+        self.scheduler_job.processor_agent = mock.MagicMock()
+
+        active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
+        assert active_dag_count == 2
+
+        self.scheduler_job._cleanup_stale_dags(session)
+
+        session.flush()
+
+        active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
+        assert active_dag_count == 1
 
     @mock.patch.object(settings, 'USE_JOB_SCHEDULE', False)
     def run_scheduler_until_dagrun_terminal(self, job: SchedulerJob):
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 173a08fb70829..78ec31d3f45ff 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -798,7 +798,7 @@ def test_bulk_write_to_db_max_active_runs(self, state):
         session = settings.Session()
         dag.clear()
-        DAG.bulk_write_to_db([dag], session)
+        DAG.bulk_write_to_db([dag], session=session)
 
         model = session.query(DagModel).get((dag.dag_id,))
 
@@ -832,7 +832,7 @@ def test_bulk_write_to_db_has_import_error(self):
         session = settings.Session()
         dag.clear()
-        DAG.bulk_write_to_db([dag], session)
+        DAG.bulk_write_to_db([dag], session=session)
 
         model = session.query(DagModel).get((dag.dag_id,))
 
@@ -871,7 +871,7 @@ def test_bulk_write_to_db_datasets(self):
             EmptyOperator(task_id=task_id, dag=dag2, outlets=[Dataset(uri1, extra={"should": "be used"})])
         session = settings.Session()
         dag1.clear()
-        DAG.bulk_write_to_db([dag1, dag2], session)
+        DAG.bulk_write_to_db([dag1, dag2], session=session)
         session.commit()
         stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()}
         d1 = stored_datasets[d1.uri]
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index c2a4f2aafa0be..b9759f64a078d 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -815,9 +815,9 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag,
         # Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully
         mock_bulk_write_to_db.assert_has_calls(
             [
-                mock.call(mock.ANY, session=mock.ANY),
-                mock.call(mock.ANY, session=mock.ANY),
-                mock.call(mock.ANY, session=mock.ANY),
+                mock.call(mock.ANY, processor_subdir=None, session=mock.ANY),
+                mock.call(mock.ANY, processor_subdir=None, session=mock.ANY),
+                mock.call(mock.ANY, processor_subdir=None, session=mock.ANY),
             ]
         )
         # Assert that rollback is called twice (i.e. whenever OperationalError occurs)
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 5e7c9e804654d..60858765ffd29 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -426,6 +426,8 @@ def on_success_callable(context):
             start_date=datetime.datetime(2017, 1, 1),
             on_success_callback=on_success_callable,
         )
+        DAG.bulk_write_to_db(dags=[dag], processor_subdir='/tmp/test', session=session)
+
         dag_task1 = EmptyOperator(task_id='test_state_succeeded1', dag=dag)
         dag_task2 = EmptyOperator(task_id='test_state_succeeded2', dag=dag)
         dag_task1.set_downstream(dag_task2)
@@ -449,6 +451,7 @@ def on_success_callable(context):
             dag_id="test_dagrun_update_state_with_handle_callback_success",
             run_id=dag_run.run_id,
             is_failure_callback=False,
+            processor_subdir='/tmp/test',
             msg="success",
         )
@@ -461,6 +464,8 @@ def on_failure_callable(context):
             start_date=datetime.datetime(2017, 1, 1),
             on_failure_callback=on_failure_callable,
         )
+        DAG.bulk_write_to_db(dags=[dag], processor_subdir='/tmp/test', session=session)
+
         dag_task1 = EmptyOperator(task_id='test_state_succeeded1', dag=dag)
         dag_task2 = EmptyOperator(task_id='test_state_failed2', dag=dag)
         dag_task1.set_downstream(dag_task2)
@@ -484,6 +489,7 @@ def on_failure_callable(context):
             dag_id="test_dagrun_update_state_with_handle_callback_failure",
             run_id=dag_run.run_id,
             is_failure_callback=True,
+            processor_subdir='/tmp/test',
             msg="task_failure",
         )
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index c9b3a63cfef17..121474747e202 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -90,7 +90,7 @@ def test_write_dag(self):
             # Verifies JSON schema.
             SerializedDAG.validate_schema(result.data)
 
-    def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
+    def test_serialized_dag_is_updated_if_dag_is_changed(self):
         """Test Serialized DAG is updated if DAG is changed"""
         example_dags = make_example_dags(example_dags_module)
         example_bash_op_dag = example_dags.get("example_bash_operator")
@@ -121,6 +121,33 @@ def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
         assert s_dag_2.data["dag"]["tags"] == ["example", "example2", "new_tag"]
         assert dag_updated is True
 
+    def test_serialized_dag_is_updated_if_processor_subdir_changed(self):
+        """Test Serialized DAG is updated if processor_subdir is changed"""
+        example_dags = make_example_dags(example_dags_module)
+        example_bash_op_dag = example_dags.get("example_bash_operator")
+        dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir='/tmp/test')
+        assert dag_updated is True
+
+        with create_session() as session:
+            s_dag = session.query(SDM).get(example_bash_op_dag.dag_id)
+
+            # Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
+            # column is not updated
+            dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir='/tmp/test')
+            s_dag_1 = session.query(SDM).get(example_bash_op_dag.dag_id)
+
+            assert s_dag_1.dag_hash == s_dag.dag_hash
+            assert s_dag.last_updated == s_dag_1.last_updated
+            assert dag_updated is False
+            session.flush()
+
+            # Update DAG
+            dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir='/tmp/other')
+            s_dag_2 = session.query(SDM).get(example_bash_op_dag.dag_id)
+
+            assert s_dag.processor_subdir != s_dag_2.processor_subdir
+            assert dag_updated is True
+
     def test_read_dags(self):
         """DAGs can be read from database."""
         example_dags = self._write_example_dags()
diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py
index 37cf0fe14e09f..480c2e7e263d8 100644
--- a/tests/test_utils/perf/perf_kit/sqlalchemy.py
+++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py
@@ -231,7 +231,7 @@ def case():
             },
         ):
             log = logging.getLogger(__name__)
-            processor = DagFileProcessor(dag_ids=[], log=log)
+            processor = DagFileProcessor(dag_ids=[], dag_directory="/tmp", log=log)
             dag_file = os.path.join(os.path.dirname(__file__), os.path.pardir, "dags", "elastic_dag.py")
             processor.process_file(file_path=dag_file, callback_requests=[])
diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py
index e17586e5f8af6..27c37a9292793 100644
--- a/tests/www/views/test_views_home.py
+++ b/tests/www/views/test_views_home.py
@@ -121,7 +121,7 @@ def client_single_dag(app, user_single_dag):
 
 
 def _process_file(file_path, session):
-    dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+    dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory='/tmp', log=mock.MagicMock())
     dag_file_processor.process_file(file_path, [], False, session)