Skip to content

Commit

Permalink
Notes stored in separate table (apache#27849)
Browse files Browse the repository at this point in the history
* wip

* try revert rename

* simplify

* working, minimally

* more reverting of notes -> note rename

* more reverting of notes -> note rename

* more reverting of notes -> note rename

* remove scratch code

* remove test speedup

* restore admin view

* add migration

* add migration

* tod

* fix migration

* Add DagRunNote

* add migration file

* disamble notes in search

* fix dagrun tests

* fix some tests and tighten up relationships, i think

* remove notes from create_dagrun method

* more cleanup

* fix collation

* fix db cleanup test

* more test fixup

* more test fixup

* rename to tinote

* rename fixup

* Don't import FAB user models just to define FK rel

We don't (currently) define any relationships it's just for making the
FK match the migration, so for now we can have the FK col defined as a
string.

When we eventually add a relationship to the get the creator of the
note, we should move the FAB User model into airflow.models and change
Security manager code to import from there instead.

* Avoid touching test file unnecessarily

* fix import

* Apply suggestions from code review

* Test that a user_id is set when creating note via api

* Fix static checks

Co-authored-by: Ash Berlin-Taylor <[email protected]>
Co-authored-by: Jed Cunningham <[email protected]>
Co-authored-by: Jed Cunningham <[email protected]>
  • Loading branch information
4 people authored Nov 24, 2022
1 parent b18f319 commit adccc2d
Show file tree
Hide file tree
Showing 23 changed files with 1,602 additions and 1,397 deletions.
5 changes: 0 additions & 5 deletions airflow/api/common/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def _trigger_dag(
conf: dict | str | None = None,
execution_date: datetime | None = None,
replace_microseconds: bool = True,
notes: str | None = None,
) -> list[DagRun | None]:
"""Triggers DAG run.
Expand Down Expand Up @@ -93,7 +92,6 @@ def _trigger_dag(
external_trigger=True,
dag_hash=dag_bag.dags_hash.get(dag_id),
data_interval=data_interval,
notes=notes,
)
dag_runs.append(dag_run)

Expand All @@ -106,7 +104,6 @@ def trigger_dag(
conf: dict | str | None = None,
execution_date: datetime | None = None,
replace_microseconds: bool = True,
notes: str | None = None,
) -> DagRun | None:
"""Triggers execution of DAG specified by dag_id.
Expand All @@ -115,7 +112,6 @@ def trigger_dag(
:param conf: configuration
:param execution_date: date of execution
:param replace_microseconds: whether microseconds should be zeroed
:param notes: set a custom note for the newly created DagRun
:return: first dag run triggered - even if more than one Dag Runs were triggered or None
"""
dag_model = DagModel.get_current(dag_id)
Expand All @@ -130,7 +126,6 @@ def trigger_dag(
conf=conf,
execution_date=execution_date,
replace_microseconds=replace_microseconds,
notes=notes,
)

return triggers[0] if triggers else None
12 changes: 10 additions & 2 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
session=session,
)
return dagrun_schema.dump(dag_run)
except ValueError as ve:
Expand Down Expand Up @@ -412,7 +413,7 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
include_parentdag=True,
only_failed=False,
)
dag_run.refresh_from_db()
dag_run = session.query(DagRun).filter(DagRun.id == dag_run.id).one()
return dagrun_schema.dump(dag_run)


Expand All @@ -437,6 +438,13 @@ def set_dag_run_notes(*, dag_id: str, dag_run_id: str, session: Session = NEW_SE
except ValidationError as err:
raise BadRequest(detail=str(err))

dag_run.notes = new_value_for_notes or None
from flask_login import current_user

current_user_id = getattr(current_user, "id", None)
if dag_run.dag_run_note is None:
dag_run.notes = (new_value_for_notes, current_user_id)
else:
dag_run.dag_run_note.content = new_value_for_notes
dag_run.dag_run_note.user_id = current_user_id
session.commit()
return dagrun_schema.dump(dag_run)
10 changes: 8 additions & 2 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,13 @@ def set_task_instance_notes(
raise NotFound(error_message)

ti, sla_miss = result
ti.notes = new_value_for_notes or None
session.commit()
from flask_login import current_user

current_user_id = getattr(current_user, "id", None)
if ti.task_instance_note is None:
ti.notes = (new_value_for_notes, current_user_id)
else:
ti.task_instance_note.content = new_value_for_notes
ti.task_instance_note.user_id = current_user_id
session.commit()
return task_instance_schema.dump((ti, sla_miss))
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Add DagRunNote and TaskInstanceNote
Revision ID: 1986afd32c1b
Revises: ee8d93fcc81e
Create Date: 2022-11-22 21:49:05.843439
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

from airflow.migrations.db_types import StringID
from airflow.utils.sqlalchemy import UtcDateTime

# revision identifiers, used by Alembic.
revision = "1986afd32c1b"
down_revision = "ee8d93fcc81e"
branch_labels = None
depends_on = None
airflow_version = "2.5.0"


def upgrade():
"""Apply Add DagRunNote and TaskInstanceNote"""
op.create_table(
"dag_run_note",
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("dag_run_id", sa.Integer(), nullable=False),
sa.Column(
"content", sa.String(length=1000).with_variant(sa.Text(length=1000), "mysql"), nullable=True
),
sa.Column("created_at", UtcDateTime(timezone=True), nullable=False),
sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
("dag_run_id",), ["dag_run.id"], name="dag_run_note_dr_fkey", ondelete="CASCADE"
),
sa.ForeignKeyConstraint(("user_id",), ["ab_user.id"], name="dag_run_note_user_fkey"),
sa.PrimaryKeyConstraint("dag_run_id", name=op.f("dag_run_note_pkey")),
)

op.create_table(
"task_instance_note",
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("task_id", StringID(), nullable=False),
sa.Column("dag_id", StringID(), nullable=False),
sa.Column("run_id", StringID(), nullable=False),
sa.Column("map_index", sa.Integer(), nullable=False),
sa.Column(
"content", sa.String(length=1000).with_variant(sa.Text(length=1000), "mysql"), nullable=True
),
sa.Column("created_at", UtcDateTime(timezone=True), nullable=False),
sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint(
"task_id", "dag_id", "run_id", "map_index", name=op.f("task_instance_note_pkey")
),
sa.ForeignKeyConstraint(
("dag_id", "task_id", "run_id", "map_index"),
[
"task_instance.dag_id",
"task_instance.task_id",
"task_instance.run_id",
"task_instance.map_index",
],
name="task_instance_note_ti_fkey",
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(("user_id",), ["ab_user.id"], name="task_instance_note_user_fkey"),
)


def downgrade():
"""Unapply Add DagRunNote and TaskInstanceNote"""
op.drop_table("task_instance_note")
op.drop_table("dag_run_note")

This file was deleted.

3 changes: 0 additions & 3 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2552,7 +2552,6 @@ def create_dagrun(
dag_hash: str | None = None,
creating_job_id: int | None = None,
data_interval: tuple[datetime, datetime] | None = None,
notes: str | None = None,
):
"""
Creates a dag run from this dag including the tasks associated with this dag.
Expand All @@ -2569,7 +2568,6 @@ def create_dagrun(
:param session: database session
:param dag_hash: Hash of Serialized DAG
:param data_interval: Data interval of the DagRun
:param notes: A custom note for the DAGRun.
"""
logical_date = timezone.coerce_datetime(execution_date)

Expand Down Expand Up @@ -2628,7 +2626,6 @@ def create_dagrun(
dag_hash=dag_hash,
creating_job_id=creating_job_id,
data_interval=data_interval,
notes=notes,
)
session.add(run)
session.flush()
Expand Down
55 changes: 52 additions & 3 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Boolean,
Column,
ForeignKey,
ForeignKeyConstraint,
Index,
Integer,
PickleType,
Expand All @@ -40,6 +41,7 @@
text,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import joinedload, relationship, synonym
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -85,6 +87,16 @@ class TISchedulingDecision(NamedTuple):
finished_tis: list[TI]


def _creator_note(val):
"""Custom creator for the ``note`` association proxy."""
if isinstance(val, str):
return DagRunNote(content=val)
elif isinstance(val, dict):
return DagRunNote(**val)
else:
return DagRunNote(*val)


class DagRun(Base, LoggingMixin):
"""
DagRun describes an instance of a Dag. It can be created
Expand All @@ -111,7 +123,6 @@ class DagRun(Base, LoggingMixin):
# When a scheduler last attempted to schedule TIs for this DagRun
last_scheduling_decision = Column(UtcDateTime)
dag_hash = Column(String(32))
notes = Column(String(1000).with_variant(Text(1000), "mysql"))
# Foreign key to LogTemplate. DagRun rows created prior to this column's
# existence have this set to NULL. Later rows automatically populate this on
# insert to point to the latest LogTemplate entry.
Expand Down Expand Up @@ -163,6 +174,8 @@ class DagRun(Base, LoggingMixin):
uselist=False,
viewonly=True,
)
dag_run_note = relationship("DagRunNote", back_populates="dag_run", uselist=False)
notes = association_proxy("dag_run_note", "content", creator=_creator_note)

DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
"scheduler",
Expand All @@ -184,7 +197,6 @@ def __init__(
dag_hash: str | None = None,
creating_job_id: int | None = None,
data_interval: tuple[datetime, datetime] | None = None,
notes: str | None = None,
):
if data_interval is None:
# Legacy: Only happen for runs created prior to Airflow 2.2.
Expand All @@ -207,7 +219,6 @@ def __init__(
self.run_type = run_type
self.dag_hash = dag_hash
self.creating_job_id = creating_job_id
self.notes = notes
super().__init__()

def __repr__(self):
Expand Down Expand Up @@ -1295,3 +1306,41 @@ def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
stacklevel=2,
)
return self.get_log_template(session=session).filename


class DagRunNote(Base):
"""For storage of arbitrary notes concerning the dagrun instance."""

__tablename__ = "dag_run_note"

user_id = Column(Integer, nullable=True)
dag_run_id = Column(Integer, primary_key=True, nullable=False)
content = Column(String(1000).with_variant(Text(1000), "mysql"))
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

dag_run = relationship("DagRun", back_populates="dag_run_note")

__table_args__ = (
ForeignKeyConstraint(
(dag_run_id,),
["dag_run.id"],
name="dag_run_note_dr_fkey",
ondelete="CASCADE",
),
ForeignKeyConstraint(
(user_id,),
["ab_user.id"],
name="dag_run_note_user_fkey",
),
)

def __init__(self, content, user_id=None):
self.content = content
self.user_id = user_id

def __repr__(self):
prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.dagrun_id} {self.run_id}"
if self.map_index != -1:
prefix += f" map_index={self.map_index}"
return prefix + ">"
Loading

0 comments on commit adccc2d

Please sign in to comment.