Unit tests!
mistercrunch committed Feb 9, 2015
1 parent f4c5963 commit bdb205a
Showing 10 changed files with 193 additions and 58 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,8 @@
docs/_*
airflow.db
build
cover
.coverage
dist
env
initdb.py
37 changes: 1 addition & 36 deletions airflow/bin/cli.py
@@ -266,42 +266,7 @@ def initdb(args):
"Proceed? (y/n)").upper() == "Y":
logging.basicConfig(level=logging.DEBUG,
format=settings.SIMPLE_LOG_FORMAT)

from airflow import models

logging.info("Dropping tables that exist")
models.Base.metadata.drop_all(settings.engine)

logging.info("Creating all tables")
models.Base.metadata.create_all(settings.engine)

# Creating the local_mysql DB connection
session = settings.Session()
session.query(models.Connection).delete()
session.add(
models.Connection(
conn_id='local_mysql', conn_type='mysql',
host='localhost', login='airflow', password='airflow',
schema='airflow'))
session.commit()
session.add(
models.Connection(
conn_id='mysql_default', conn_type='mysql',
host='localhost', login='airflow', password='airflow',
schema='airflow'))
session.commit()
session.add(
models.Connection(
conn_id='presto_default', conn_type='presto',
host='localhost',
schema='hive', port=10001))
session.commit()
session.add(
models.Connection(
conn_id='hive_default', conn_type='hive',
host='localhost',
schema='default', port=10000))
session.commit()
utils.resetdb()
else:
print("Bail.")

51 changes: 51 additions & 0 deletions airflow/configuration.py
@@ -41,6 +41,44 @@
id_len = 250
"""

TEST_CONFIG = """\
[core]
airflow_home = {AIRFLOW_HOME}
authenticate = False
dags_folder = {AIRFLOW_HOME}/dags
base_log_folder = {AIRFLOW_HOME}/logs
base_url = http://localhost:8080
executor = SequentialExecutor
sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/tests.db
[server]
web_server_host = 0.0.0.0
web_server_port = 8080
[smtp]
smtp_host = localhost
smtp_user = airflow
smtp_port = 25
smtp_password = airflow
smtp_mail_from = [email protected]
[celery]
celery_app_name = airflow.executors.celery_executor
celeryd_concurrency = 16
worker_log_server_port = 8793
broker_url = sqla+mysql://airflow:airflow@localhost:3306/airflow
celery_result_backend = db+mysql://airflow:airflow@localhost:3306/airflow
flower_port = 5555
[hooks]
presto_default_conn_id = presto_default
hive_default_conn_id = hive_default
[misc]
job_heartbeat_sec = 1
id_len = 250
"""

def mkdir_p(path):
try:
os.makedirs(path)
@@ -78,5 +116,18 @@ def mkdir_p(path):
f.write(DEFAULT_CONFIG.format(**locals()))
f.close()

TEST_CONFIG_FILE = AIRFLOW_HOME + '/unittests.cfg'
if not os.path.isfile(TEST_CONFIG_FILE):
logging.info("Createing new config file in: " + TEST_CONFIG_FILE)
f = open(TEST_CONFIG_FILE, 'w')
f.write(TEST_CONFIG.format(**locals()))
f.close()

logging.info("Reading the config from " + AIRFLOW_CONFIG)

def test_mode():
    conf.read(TEST_CONFIG_FILE)
    print("Using configuration located at: " + TEST_CONFIG_FILE)

conf.read(AIRFLOW_CONFIG)
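
A minimal sketch of the intended bootstrap order, mirroring what tests/core.py does below: call test_mode() immediately after the first airflow import, so everything loaded afterwards reads the sqlite-backed test settings.

    from airflow import configuration
    configuration.test_mode()  # re-point the parser at unittests.cfg

    # Modules imported after this point pick up the test settings
    # (sqlite metadata DB, SequentialExecutor, 1-second job heartbeat):
    from airflow import jobs, models, utils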
4 changes: 2 additions & 2 deletions airflow/example_dags/example_bash_operator.py
@@ -20,11 +20,11 @@
default_args=args)
dag.add_task(run_this)
run_this.set_downstream(run_this_last)
for i in range(9):
for i in range(5):
i = str(i)
task = BashOperator(
task_id='runme_'+i,
bash_command='echo "{{ task_instance_key_str }}" && sleep 5',
bash_command='echo "{{ task_instance_key_str }}" && sleep ' + str(i),
default_args=args)
task.set_downstream(run_this)
dag.add_task(task)
7 changes: 3 additions & 4 deletions airflow/executors/base_executor.py
@@ -9,7 +9,7 @@ def __init__(self):
self.commands = {}
self.event_buffer = {}

def start(self):
def start(self): # pragma: no cover
"""
Executors may need to get things started. For example LocalExecutor
starts N workers.
@@ -36,17 +36,16 @@ def get_event_buffer(self):
self.event_buffer = {}
return d

def execute_async(self, key, command):
def execute_async(self, key, command): # pragma: no cover
"""
This method will execute the command asynchronously.
"""
raise NotImplementedError()

def end(self):
def end(self): # pragma: no cover
"""
        This method is called when the caller is done submitting jobs and
        wants to wait synchronously for the previously submitted jobs to
        finish.
"""
raise NotImplementedError()
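
The `# pragma: no cover` markers exclude these abstract hooks from the coverage report, since only subclasses implement them. A hedged sketch of what a concrete executor overrides (the `InlineExecutor` name is hypothetical, and the base class is assumed to be `BaseExecutor`):

    import subprocess

    from airflow.executors.base_executor import BaseExecutor


    class InlineExecutor(BaseExecutor):
        """Hypothetical executor: runs each command synchronously."""

        def start(self):
            pass  # nothing to spin up

        def execute_async(self, key, command):
            # Not actually asynchronous: run the shell command inline.
            subprocess.check_call(command, shell=True)

        def end(self):
            pass  # no workers to drain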

6 changes: 6 additions & 0 deletions airflow/macros/__init__.py
@@ -13,7 +13,13 @@ def ds_add(ds, days):
:type ds: str
:param days: number of days to add to the ds, you can use negative values
:type days: int
>>> ds_add('2015-01-01', 5)
'2015-01-06'
>>> ds_add('2015-01-06', -5)
'2015-01-01'
'''

ds = datetime.strptime(ds, '%Y-%m-%d')
if days:
ds = ds + timedelta(days)
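
The new doctests double as executable examples; run_unit_tests.sh below passes --with-doctest so nosetests exercises them alongside the unit tests. The same round trip, spelled out:

    from airflow.macros import ds_add

    assert ds_add('2015-01-01', 5) == '2015-01-06'   # add five days
    assert ds_add('2015-01-06', -5) == '2015-01-01'  # negative values subtract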
41 changes: 40 additions & 1 deletion airflow/utils.py
@@ -11,6 +11,7 @@
from sqlalchemy.pool import Pool

from airflow.configuration import conf
from airflow import settings


class State(object):
@@ -57,7 +58,45 @@ def ping_connection(dbapi_connection, connection_record, connection_proxy):
raise exc.DisconnectionError()
cursor.close()


def resetdb():
'''
Clear out the database
'''
from airflow import models

logging.info("Dropping tables that exist")
models.Base.metadata.drop_all(settings.engine)

logging.info("Creating all tables")
models.Base.metadata.create_all(settings.engine)

    # Recreate the default DB connections
session = settings.Session()
session.query(models.Connection).delete()
session.add(
models.Connection(
conn_id='local_mysql', conn_type='mysql',
host='localhost', login='airflow', password='airflow',
schema='airflow'))
session.commit()
session.add(
models.Connection(
conn_id='mysql_default', conn_type='mysql',
host='localhost', login='airflow', password='airflow',
schema='airflow'))
session.commit()
session.add(
models.Connection(
conn_id='presto_default', conn_type='presto',
host='localhost',
schema='hive', port=10001))
session.commit()
session.add(
models.Connection(
conn_id='hive_default', conn_type='hive',
host='localhost',
schema='default', port=10000))
session.commit()

def validate_key(k, max_length=250):
if type(k) is not str:
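
With the reset logic moved into utils, tests can rebuild a clean metadata DB per run, as CoreTest.setUp does below. A quick sanity-check sketch of what resetdb() seeds (assumes the sqlite test DB from unittests.cfg):

    from airflow import models, settings, utils

    utils.resetdb()
    session = settings.Session()
    conn_ids = sorted(c.conn_id for c in session.query(models.Connection))
    assert conn_ids == [
        'hive_default', 'local_mysql', 'mysql_default', 'presto_default']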
3 changes: 3 additions & 0 deletions run_unit_tests.sh
@@ -0,0 +1,3 @@
export AIRFLOW_CONFIG=~/airflow/unittests.cfg
nosetests --with-doctest --with-coverage --cover-html --cover-package=airflow
python -m SimpleHTTPServer 8001
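
The script points AIRFLOW_CONFIG at the unittest config, runs the suite with doctests and coverage enabled, then serves the HTML report written to cover/ (ignored via the .gitignore change above) on port 8001. A rough Python equivalent of the export, assuming configuration.py honors the AIRFLOW_CONFIG environment variable as the script implies:

    import os

    # Must happen before the first airflow import (assumption: configuration.py
    # checks this environment variable when locating its config file):
    os.environ['AIRFLOW_CONFIG'] = os.path.expanduser('~/airflow/unittests.cfg')
    from airflow import configuration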
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@
from core import *
99 changes: 84 additions & 15 deletions tests/core.py
@@ -1,32 +1,101 @@
from datetime import datetime
import unittest

import airflow
from airflow import configuration
configuration.test_mode()
from airflow import jobs, models, executors, utils
from airflow.www.app import app

NUM_EXAMPLE_DAGS = 3
DEV_NULL = '/dev/null'
LOCAL_EXECUTOR = airflow.executors.LocalExecutor()
LOCAL_EXECUTOR = executors.LocalExecutor()
DEFAULT_DATE = datetime(2015, 1, 1)


class CoreTest(unittest.TestCase):

def setUp(self):
pass
configuration.test_mode()
utils.resetdb()
self.dagbag = models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
self.dag_bash = self.dagbag.dags['example_bash_operator']
self.runme_0 = self.dag_bash.get_task('runme_0')

def test_import_examples(self):
dagbag = airflow.models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
self.assertEqual(len(dagbag.dags), NUM_EXAMPLE_DAGS)
self.assertEqual(len(self.dagbag.dags), NUM_EXAMPLE_DAGS)

def test_backfill_example1(self):
dagbag = airflow.models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
dag = dagbag.dags['example1']
TI = airflow.models.TaskInstance
ti = TI.get_or_create(
task=dag.get_task('runme_0'), execution_date=DEFAULT_DATE)
airflow.jobs.RunJob(task_instance=ti, force=True)
def test_local_task_job(self):
TI = models.TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
job = jobs.LocalTaskJob(task_instance=ti, force=True)
job.run()

def test_local_backfill_job(self):
self.dag_bash.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE)
job = jobs.BackfillJob(
dag=self.dag_bash,
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE)
job.run()

def test_raw_job(self):
TI = models.TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
ti.run(force=True)


class WebUiTests(unittest.TestCase):

def setUp(self):
configuration.test_mode()
app.config['TESTING'] = True
self.app = app.test_client()

def test_index(self):
response = self.app.get('/', follow_redirects=True)
assert "DAGs" in response.data
assert "example_bash_operator" in response.data

def test_query(self):
response = self.app.get('/admin/airflow/query')
assert "Ad Hoc Query" in response.data
response = self.app.get(
"/admin/airflow/query?conn_id=local_mysql&sql=SELECT+*+FROM+dag")
assert "example_bash_operator" in response.data


def test_health(self):
response = self.app.get('/health')
assert 'The server is healthy!' in response.data

def test_dag_views(self):
response = self.app.get(
'/admin/airflow/graph?dag_id=example_bash_operator')
assert "runme_0" in response.data
response = self.app.get(
'/admin/airflow/tree?num_runs=25&dag_id=example_bash_operator')
assert "runme_0" in response.data
response = self.app.get(
'/admin/airflow/duration?days=30&dag_id=example_bash_operator')
assert "DAG: example_bash_operator" in response.data
response = self.app.get(
'/admin/airflow/landing_times?'
'days=30&dag_id=example_bash_operator')
assert "DAG: example_bash_operator" in response.data
response = self.app.get(
'/admin/airflow/gantt?dag_id=example_bash_operator')
assert "DAG: example_bash_operator" in response.data
response = self.app.get(
'/admin/airflow/code?dag_id=example_bash_operator')
assert "DAG: example_bash_operator" in response.data

def tearDown(self):
pass


if __name__ == '__main__':
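
Individual cases can also be run without the shell script; a minimal sketch using stock unittest (assumes the generated unittests.cfg and the example DAGs above):

    import unittest

    from tests.core import CoreTest

    suite = unittest.TestSuite([CoreTest('test_raw_job')])
    unittest.TextTestRunner(verbosity=2).run(suite)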
