More unit tests
mistercrunch committed Feb 10, 2015
1 parent fdb8149 commit 9d78158
Showing 7 changed files with 73 additions and 13 deletions.
4 changes: 2 additions & 2 deletions airflow/configuration.py
@@ -50,7 +50,7 @@
base_log_folder = {AIRFLOW_HOME}/logs
base_url = http://localhost:8080
executor = SequentialExecutor
-sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/tests.db
+sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db
unit_test_mode = True
[server]
@@ -108,7 +108,6 @@ def mkdir_p(path):
else:
AIRFLOW_CONFIG = os.environ['AIRFLOW_CONFIG']

-conf = ConfigParser()
if not os.path.isfile(AIRFLOW_CONFIG):
'''
These configurations are used to generate a default configuration when
@@ -135,4 +134,5 @@ def test_mode():
conf.read(TEST_CONFIG)
print("Using configuration located at: " + TEST_CONFIG)

+conf = ConfigParser()
conf.read(AIRFLOW_CONFIG)
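
Moving conf = ConfigParser() down so the object is built just before the config is read keeps a single global conf that test_mode() can later re-read in place. A minimal sketch of flipping a process into unit-test mode, assuming the module layout in this commit (the 'core' section name is an assumption, not shown in the hunks above):

    # Sketch only; assumes airflow.configuration as of this commit,
    # and that sql_alchemy_conn lives in a section named 'core'.
    from airflow import configuration

    configuration.test_mode()  # re-points conf at the unittests config
    print(configuration.conf.get('core', 'sql_alchemy_conn'))
    # expected to end in unittests.db, per the change above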
51 changes: 51 additions & 0 deletions airflow/hooks/hive_hook.py
@@ -70,6 +70,14 @@ def get_conn(self):
return self.hive

def check_for_partition(self, schema, table, partition):
+'''
+Checks whether a partition exists
+>>> hh = HiveHook()
+>>> t = 'static_babynames_partitioned'
+>>> hh.check_for_partition('airflow', t, "year='2008'")
+True
+'''
self.hive._oprot.trans.open()
partitions = self.hive.get_partitions_by_filter(
schema, table, partition, 1)
@@ -82,6 +90,11 @@ def check_for_partition(self, schema, table, partition):
def get_records(self, hql, schema=None):
'''
Get a set of records from a Hive query.
+>>> hh = HiveHook()
+>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
+>>> hh.get_records(sql)
+[['340698']]
'''
self.hive._oprot.trans.open()
if schema:
@@ -92,6 +105,15 @@ def get_records(self, hql, schema=None):
return [row.split("\t") for row in records]

def get_pandas_df(self, hql, schema=None):
+'''
+Get a pandas dataframe from a Hive query
+>>> hh = HiveHook()
+>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
+>>> df = hh.get_pandas_df(sql)
+>>> df.to_dict()
+{0: {0: '340698'}}
+'''
import pandas as pd
self.hive._oprot.trans.open()
if schema:
@@ -110,6 +132,12 @@ def run(self, hql, schema=None):
self.hive._oprot.trans.close()

def run_cli(self, hql, schema=None):
+'''
+Run an HQL statement using the Hive CLI
+>>> hh = HiveHook()
+>>> hh.run_cli("USE airflow;")
+'''
if schema:
hql = "USE {schema};\n{hql}".format(**locals())

@@ -132,6 +160,16 @@ def run_cli(self, hql, schema=None):
raise Exception(all_err)

def get_table(self, db, table_name):
+'''
+Get a metastore table object
+>>> hh = HiveHook()
+>>> t = hh.get_table(db='airflow', table_name='static_babynames')
+>>> t.tableName
+'static_babynames'
+>>> [col.name for col in t.sd.cols]
+['state', 'year', 'name', 'gender', 'num']
+'''
self.hive._oprot.trans.open()
table = self.hive.get_table(dbname=db, tbl_name=table_name)
self.hive._oprot.trans.close()
@@ -142,6 +180,14 @@ def get_partitions(self, schema, table_name):
Returns a list of all partitions in a table. Works only
for tables with less than 32767 (java short max val).
For subpartitioned tables, the number might easily exceed this.
+>>> hh = HiveHook()
+>>> t = 'static_babynames_partitioned'
+>>> parts = hh.get_partitions(schema='airflow', table_name=t)
+>>> len(parts)
+49
+>>> max(parts)
+'2008'
'''
self.hive._oprot.trans.open()
table = self.hive.get_table(dbname=schema, tbl_name=table_name)
@@ -163,6 +209,11 @@ def max_partition(self, schema, table_name):
Returns the maximum value for all partitions in a table. Works only
for tables that have a single partition key. For subpartitioned
tables, we recommend using signal tables.
+>>> hh = HiveHook()
+>>> t = 'static_babynames_partitioned'
+>>> hh.max_partition(schema='airflow', table_name=t)
+'2008'
'''
return max(self.get_partitions(schema, table_name))

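Taken together, the new docstrings double as usage documentation for the hook. A usage sketch, assuming a reachable metastore with the airflow.static_babynames fixture tables the doctests reference:

    from airflow.hooks.hive_hook import HiveHook

    hh = HiveHook()
    # get_records returns lists of strings: rows are split on tabs
    # (see row.split("\t") above), so the count comes back as '340698'.
    [[count]] = hh.get_records(
        "SELECT count(1) FROM airflow.static_babynames")
    latest = hh.max_partition(schema='airflow',
                              table_name='static_babynames_partitioned')
    print(count, latest)  # '340698' and '2008' against the fixture data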
5 changes: 5 additions & 0 deletions airflow/hooks/presto_hook.py
@@ -11,6 +11,11 @@
class PrestoHook(BaseHook):
"""
Interact with Presto through PyHive!
+>>> ph = PrestoHook()
+>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
+>>> ph.get_records(sql)
+[[340698]]
"""
def __init__(self, host=None, db=None, port=None,
presto_conn_id=conf.get('hooks', 'PRESTO_DEFAULT_CONN_ID')):
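One detail the doctests surface: PrestoHook returns native Python types ([[340698]]) where HiveHook returns strings (['340698']), because the Hive hook splits raw tab-delimited output. A sketch of code that relies on this, under the same fixture-table assumption:

    from airflow.hooks.presto_hook import PrestoHook

    ph = PrestoHook()
    sql = "SELECT count(1) AS num FROM airflow.static_babynames"
    [[num]] = ph.get_records(sql)
    assert isinstance(num, int)  # HiveHook would hand back '340698'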
1 change: 0 additions & 1 deletion airflow/jobs.py
@@ -192,7 +192,6 @@ def signal_handler(signum, frame):
executor.start()
i = 0
while (not self.test_mode) or i < 1:
-print(i)
i += 1
self.heartbeat()
dagbag.collect_dags(only_if_updated=True)
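The loop the deleted debug print lived in uses a small idiom worth noting: the condition (not self.test_mode) or i < 1 loops forever in production but runs exactly one iteration under tests. A distilled, self-contained sketch of the idiom (names are illustrative, not from the commit):

    def heartbeat_once():
        print("heartbeat")  # stand-in for self.heartbeat() and dag collection

    test_mode = True
    i = 0
    while (not test_mode) or i < 1:
        i += 1
        heartbeat_once()
    # test_mode=False: the scheduler loops indefinitely;
    # test_mode=True: the body executes exactly once.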
14 changes: 6 additions & 8 deletions airflow/utils.py
@@ -7,7 +7,7 @@
import re
import smtplib

-from sqlalchemy import event
+from sqlalchemy import event, exc
from sqlalchemy.pool import Pool

from airflow.configuration import conf
@@ -83,7 +83,7 @@ def initdb():
models.Connection(
conn_id='presto_default', conn_type='presto',
host='localhost',
-schema='hive', port=10001))
+schema='hive', port=3400))
session.commit()

conn = session.query(C).filter(C.conn_id == 'hive_default').first()
@@ -206,17 +206,17 @@ def wrapper(*args, **kwargs):


def ask_yesno(question):
-yes = set(['yes','y',])
-no = set(['no','n'])
+yes = set(['yes', 'y'])
+no = set(['no', 'n'])

done = False
print(question)
while not done:
choice = raw_input().lower()
if choice in yes:
-return True
+return True
elif choice in no:
-return False
+return False
else:
print("Please respond by yes or no.")

@@ -249,5 +249,3 @@ def send_email(to, subject, html_content):
logging.info("Sent an alert email to " + str(to))
s.sendmail(SMTP_MAIL_FROM, to, msg.as_string())
s.quit()
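
The visible hunks import exc without showing where it is used; a common reason to pair it with the Pool events already imported here is the standard SQLAlchemy connection-ping recipe, sketched below as an assumption about intent rather than a quote from this commit:

    from sqlalchemy import event, exc
    from sqlalchemy.pool import Pool

    @event.listens_for(Pool, "checkout")
    def ping_connection(dbapi_connection, connection_record, connection_proxy):
        # Probe the connection on checkout; raising DisconnectionError
        # tells the pool to discard it and retry with a fresh one.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("SELECT 1")
        except Exception:
            raise exc.DisconnectionError()
        cursor.close()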


2 changes: 1 addition & 1 deletion run_unit_tests.sh
@@ -1,3 +1,3 @@
export AIRFLOW_CONFIG=~/airflow/unittests.cfg
-nosetests --with-doctest --with-coverage --cover-html --cover-package=airflow #--nocapture
+nosetests --with-doctest --with-coverage --cover-html --cover-package=airflow --nocapture
#python -m SimpleHTTPServer 8002
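
The --with-doctest flag is what turns the >>> examples added throughout the hooks into executable tests, and un-commenting --nocapture lets their print output through. The same doctests can also be run standalone; a sketch, assuming a live Hive environment with the fixture tables:

    import doctest
    from airflow.hooks import hive_hook

    # Executes the >>> examples added in this commit; requires the
    # airflow.static_babynames tables to exist.
    doctest.testmod(hive_hook, verbose=True)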
9 changes: 8 additions & 1 deletion tests/core.py
@@ -16,7 +16,6 @@ class CoreTest(unittest.TestCase):

def setUp(self):
configuration.test_mode()
print("INITDB")
utils.initdb()
self.dagbag = models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
@@ -106,6 +105,14 @@ def test_dag_views(self):
'/admin/airflow/code?dag_id=example_bash_operator')
assert "DAG: example_bash_operator" in response.data

+def test_charts(self):
+response = self.app.get(
+'/admin/airflow/chart?chart_id=1&iteration_no=1')
+assert "Most Popular" in response.data
+response = self.app.get(
+'/admin/airflow/chart_data?chart_id=1&iteration_no=1')
+assert "Michael" in response.data

def tearDown(self):
pass

