Skip to content

Commit

Permalink
Don't Update Serialized DAGs in DB if DAG didn't change (apache#9850)
Browse files Browse the repository at this point in the history
We should not update the "last_updated" column unnecessarily. This is first of  few optimizations to DAG Serialization that would also aid in DAG Versioning
  • Loading branch information
kaxil authored Jul 20, 2020
1 parent a0bde8e commit 1a32c45
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 12 deletions.
38 changes: 27 additions & 11 deletions airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
# specific language governing permissions and limitations
# under the License.

"""Serialzed DAG table in database."""
"""Serialized DAG table in database."""

import logging
from datetime import timedelta
from typing import Any, Dict, List, Optional

import sqlalchemy_jsonfield
from sqlalchemy import BigInteger, Column, Index, String, and_
from sqlalchemy.orm import Session
from sqlalchemy.sql import exists

from airflow.models.base import ID_LEN, Base
Expand Down Expand Up @@ -81,8 +82,10 @@ def __repr__(self):

@classmethod
@provide_session
def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session=None):
def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: Session = None):
"""Serializes a DAG and writes it into database.
If the record already exists, it checks if the Serialized DAG changed or not. If it is
changed, it updates the record, ignores otherwise.
:param dag: a DAG to be written into database
:param min_update_interval: minimal interval in seconds to update serialized DAG
Expand All @@ -97,13 +100,21 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session=
(timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated))
).scalar():
return
log.debug("Writing DAG: %s to the DB", dag.dag_id)
session.merge(cls(dag))

log.debug("Checking if DAG (%s) changed", dag.dag_id)
serialized_dag_from_db: SerializedDagModel = session.query(cls).get(dag.dag_id)
new_serialized_dag = cls(dag)
if serialized_dag_from_db and (serialized_dag_from_db.data == new_serialized_dag.data):
log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag.dag_id)
return

log.debug("Writing Serialized DAG: %s to the DB", dag.dag_id)
session.merge(new_serialized_dag)
log.debug("DAG: %s written to the DB", dag.dag_id)

@classmethod
@provide_session
def read_all_dags(cls, session=None) -> Dict[str, 'SerializedDAG']:
def read_all_dags(cls, session: Session = None) -> Dict[str, 'SerializedDAG']:
"""Reads all DAGs in serialized_dag table.
:param session: ORM Session
Expand Down Expand Up @@ -137,7 +148,7 @@ def dag(self):

@classmethod
@provide_session
def remove_dag(cls, dag_id: str, session=None):
def remove_dag(cls, dag_id: str, session: Session = None):
"""Deletes a DAG with given dag_id.
:param dag_id: dag_id to be deleted
Expand All @@ -148,14 +159,16 @@ def remove_dag(cls, dag_id: str, session=None):

@classmethod
@provide_session
def remove_stale_dags(cls, expiration_date, session=None):
def remove_stale_dags(cls, expiration_date, session: Session = None):
"""
Deletes Serialized DAGs that were last touched by the scheduler before
the expiration date. These DAGs were likely deleted.
:param expiration_date: set inactive DAGs that were touched before this
time
:type expiration_date: datetime
:param session: ORM Session
:type session: Session
:return: None
"""
log.debug("Deleting Serialized DAGs that haven't been touched by the "
Expand All @@ -168,7 +181,7 @@ def remove_stale_dags(cls, expiration_date, session=None):

@classmethod
@provide_session
def has_dag(cls, dag_id: str, session=None) -> bool:
def has_dag(cls, dag_id: str, session: Session = None) -> bool:
"""Checks a DAG exist in serialized_dag table.
:param dag_id: the DAG to check
Expand All @@ -178,7 +191,7 @@ def has_dag(cls, dag_id: str, session=None) -> bool:

@classmethod
@provide_session
def get(cls, dag_id: str, session=None) -> Optional['SerializedDagModel']:
def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagModel']:
"""
Get the SerializedDAG for the given dag ID.
It will cope with being passed the ID of a subdag by looking up the
Expand All @@ -200,12 +213,15 @@ def get(cls, dag_id: str, session=None) -> Optional['SerializedDagModel']:

@staticmethod
@provide_session
def bulk_sync_to_db(dags: List[DAG], session=None):
def bulk_sync_to_db(dags: List[DAG], session: Session = None):
"""
Saves DAGs as Seralized DAG objects in the database. Each DAG is saved in a separate database query.
Saves DAGs as Serialized DAG objects in the database. Each
DAG is saved in a separate database query.
:param dags: the DAG objects to save to the DB
:type dags: List[airflow.models.dag.DAG]
:param session: ORM Session
:type session: Session
:return: None
"""
for dag in dags:
Expand Down
32 changes: 31 additions & 1 deletion tests/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,36 @@ def test_write_dag(self):
# Verifies JSON schema.
SerializedDAG.validate_schema(result.data)

def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
"""Test Serialized DAG is updated if DAG is changed"""

example_dags = make_example_dags(example_dags_module)
example_bash_op_dag = example_dags.get("example_bash_operator")
SDM.write_dag(dag=example_bash_op_dag)

with create_session() as session:
last_updated = session.query(
SDM.last_updated).filter(SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

# Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
# column is not updated
SDM.write_dag(dag=example_bash_op_dag)
last_updated_1 = session.query(
SDM.last_updated).filter(SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

self.assertEqual(last_updated, last_updated_1)

# Update DAG
example_bash_op_dag.tags += ["new_tag"]
self.assertCountEqual(example_bash_op_dag.tags, ["example", "new_tag"])

SDM.write_dag(dag=example_bash_op_dag)
new_s_dag = session.query(SDM.last_updated, SDM.data).filter(
SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

self.assertNotEqual(last_updated, new_s_dag.last_updated)
self.assertEqual(new_s_dag.data["dag"]["tags"], ["example", "new_tag"])

def test_read_dags(self):
"""DAGs can be read from database."""
example_dags = self._write_example_dags()
Expand Down Expand Up @@ -120,5 +150,5 @@ def test_bulk_sync_to_db(self):
dags = [
DAG("dag_1"), DAG("dag_2"), DAG("dag_3"),
]
with assert_queries_count(7):
with assert_queries_count(10):
SDM.bulk_sync_to_db(dags)

0 comments on commit 1a32c45

Please sign in to comment.