Skip to content

Commit

Permalink
[AIRFLOW-3861] Use the create_session for the db session (apache#4683)
Browse files Browse the repository at this point in the history
Instead of opening and closing these session manually, it is easier
to do this using the create_session and that should be the preferred
way as well.
  • Loading branch information
Fokko authored and kaxil committed Feb 12, 2019
1 parent 21a38f1 commit 907aa00
Showing 1 changed file with 84 additions and 108 deletions.
192 changes: 84 additions & 108 deletions tests/operators/test_python_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.settings import Session
from airflow.utils import timezone
from airflow.utils.db import create_session
from airflow.utils.state import State

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
Expand All @@ -51,12 +51,9 @@ class PythonOperatorTest(unittest.TestCase):
def setUpClass(cls):
super(PythonOperatorTest, cls).setUpClass()

session = Session()

session.query(DagRun).delete()
session.query(TI).delete()
session.commit()
session.close()
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()

def setUp(self):
super(PythonOperatorTest, self).setUp()
Expand All @@ -74,13 +71,9 @@ def setUp(self):
def tearDown(self):
super(PythonOperatorTest, self).tearDown()

session = Session()

session.query(DagRun).delete()
session.query(TI).delete()
print(len(session.query(DagRun).all()))
session.commit()
session.close()
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()

for var in TI_CONTEXT_ENV_VARS:
if var in os.environ:
Expand Down Expand Up @@ -170,12 +163,9 @@ class BranchOperatorTest(unittest.TestCase):
def setUpClass(cls):
super(BranchOperatorTest, cls).setUpClass()

session = Session()

session.query(DagRun).delete()
session.query(TI).delete()
session.commit()
session.close()
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()

def setUp(self):
self.dag = DAG('branch_operator_test',
Expand All @@ -190,13 +180,9 @@ def setUp(self):
def tearDown(self):
super(BranchOperatorTest, self).tearDown()

session = Session()

session.query(DagRun).delete()
session.query(TI).delete()
print(len(session.query(DagRun).all()))
session.commit()
session.close()
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()

def test_without_dag_run(self):
"""This checks the defensive against non existent tasks in a dag run"""
Expand All @@ -209,23 +195,22 @@ def test_without_dag_run(self):

self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

session = Session()
tis = session.query(TI).filter(
TI.dag_id == self.dag.dag_id,
TI.execution_date == DEFAULT_DATE
)
session.close()

for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'branch_1':
# should exist with state None
self.assertEqual(ti.state, State.NONE)
elif ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.SKIPPED)
else:
raise Exception
with create_session() as session:
tis = session.query(TI).filter(
TI.dag_id == self.dag.dag_id,
TI.execution_date == DEFAULT_DATE
)

for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'branch_1':
# should exist with state None
self.assertEqual(ti.state, State.NONE)
elif ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.SKIPPED)
else:
raise Exception

def test_branch_list_without_dag_run(self):
"""This checks if the BranchPythonOperator supports branching off to a list of tasks."""
Expand All @@ -240,25 +225,24 @@ def test_branch_list_without_dag_run(self):

self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

session = Session()
tis = session.query(TI).filter(
TI.dag_id == self.dag.dag_id,
TI.execution_date == DEFAULT_DATE
)
session.close()

expected = {
"make_choice": State.SUCCESS,
"branch_1": State.NONE,
"branch_2": State.NONE,
"branch_3": State.SKIPPED,
}

for ti in tis:
if ti.task_id in expected:
self.assertEqual(ti.state, expected[ti.task_id])
else:
raise Exception
with create_session() as session:
tis = session.query(TI).filter(
TI.dag_id == self.dag.dag_id,
TI.execution_date == DEFAULT_DATE
)

expected = {
"make_choice": State.SUCCESS,
"branch_1": State.NONE,
"branch_2": State.NONE,
"branch_3": State.SKIPPED,
}

for ti in tis:
if ti.task_id in expected:
self.assertEqual(ti.state, expected[ti.task_id])
else:
raise Exception

def test_with_dag_run(self):
self.branch_op = BranchPythonOperator(task_id='make_choice',
Expand Down Expand Up @@ -353,22 +337,16 @@ class ShortCircuitOperatorTest(unittest.TestCase):
def setUpClass(cls):
super(ShortCircuitOperatorTest, cls).setUpClass()

session = Session()

session.query(DagRun).delete()
session.query(TI).delete()
session.commit()
session.close()
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()

def tearDown(self):
super(ShortCircuitOperatorTest, self).tearDown()

session = Session()

session.query(DagRun).delete()
session.query(TI).delete()
session.commit()
session.close()
with create_session() as session:
session.query(DagRun).delete()
session.query(TI).delete()

def test_without_dag_run(self):
"""This checks the defensive against non existent tasks in a dag run"""
Expand All @@ -392,39 +370,37 @@ def test_without_dag_run(self):

short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

session = Session()
tis = session.query(TI).filter(
TI.dag_id == dag.dag_id,
TI.execution_date == DEFAULT_DATE
)

for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'upstream':
# should not exist
raise Exception
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.SKIPPED)
else:
raise Exception

value = True
dag.clear()

short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'upstream':
# should not exist
raise Exception
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.NONE)
else:
raise Exception

session.close()
with create_session() as session:
tis = session.query(TI).filter(
TI.dag_id == dag.dag_id,
TI.execution_date == DEFAULT_DATE
)

for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'upstream':
# should not exist
raise Exception
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.SKIPPED)
else:
raise Exception

value = True
dag.clear()

short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'upstream':
# should not exist
raise Exception
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.NONE)
else:
raise Exception

def test_with_dag_run(self):
value = False
Expand Down

0 comments on commit 907aa00

Please sign in to comment.