updates to raise appropriate errors from athena client, etc (airbnb#779)
ryandeivert authored Jul 2, 2018
1 parent 82e3e3f commit d1b6b7a
Showing 8 changed files with 93 additions and 61 deletions.
2 changes: 2 additions & 0 deletions stream_alert/rule_promotion/promoter.py
@@ -104,6 +104,8 @@ def _update_alert_count(self):
row_values = [data.values()[0] for data in row['Data']]
rule_name, alert_count = row_values[0], int(row_values[1])

LOGGER.debug('Found %d alerts for rule \'%s\'', alert_count, rule_name)

self._staging_stats[rule_name].alert_count = alert_count

def run(self):
11 changes: 4 additions & 7 deletions stream_alert/rule_promotion/publisher.py
@@ -103,17 +103,10 @@ def _query_alerts(self, stat):
Returns:
str: Execution ID for running Athena query
"""
# If there are no alerts, do not run the comprehensive query
if not stat.alert_count:
return

info_statement = stat.sql_info_statement
LOGGER.debug('Querying alert info for rule \'%s\': %s', stat.rule_name, info_statement)

response = self._athena_client.run_async_query(info_statement)
if not response:
LOGGER.error('Failed to query alert info for rule: \'%s\'', stat.rule_name)
return

return response['QueryExecutionId']

@@ -156,6 +149,10 @@ def publish(self, stats):
if self._should_send_digest:

for stat in stats:
# If there are no alerts, do not run the comprehensive query
if not stat:
continue

stat.execution_id = self._query_alerts(stat)

self._publish_message(stats)
10 changes: 9 additions & 1 deletion stream_alert/rule_promotion/statistic.py
@@ -18,6 +18,8 @@
class StagingStatistic(object):
"""Store information on generated alerts."""

_ALERT_COUNT_UNKOWN = 'unknown'

_COUNT_QUERY_TEMPLATE = ("SELECT '{rule_name}' AS rule_name, count(*) AS count "
"FROM alerts WHERE dt >= '{date}-{hour:02}' AND "
"rule_name = '{rule_name}'")
@@ -34,9 +36,15 @@ def __init__(self, staged_at, staged_until, current_time, rule):
self._current_time = current_time
self._staged_at = staged_at
self.staged_until = staged_until
self.alert_count = 'unknown'
self.alert_count = self._ALERT_COUNT_UNKOWN
self.execution_id = None

def __nonzero__(self):
return self.alert_count not in {0, self._ALERT_COUNT_UNKOWN}

# For forward compatibility to Python3
__bool__ = __nonzero__

def __lt__(self, other):
"""Statistic should be ordered by their alert count."""
return self.alert_count < other.alert_count
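
The new __nonzero__/__bool__ hook makes a StagingStatistic falsy whenever its alert count is zero or still the 'unknown' sentinel, which is what the new "if not stat: continue" check in the publisher relies on. A minimal sketch of the expected behavior follows (not part of the commit; the constructor values are placeholders and assume datetime-like staging timestamps):

from datetime import datetime, timedelta

from stream_alert.rule_promotion.statistic import StagingStatistic

now = datetime.utcnow()
stat = StagingStatistic(
    staged_at=now - timedelta(days=1),
    staged_until=now + timedelta(days=1),
    current_time=now,
    rule='example_rule'  # placeholder rule name for this sketch
)

assert not stat   # alert_count defaults to 'unknown', so the stat is falsy

stat.alert_count = 0
assert not stat   # zero alerts is also falsy, so its info query is skipped

stat.alert_count = 10
assert stat       # only a known, non-zero count makes the stat truthy
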
68 changes: 47 additions & 21 deletions stream_alert/shared/athena.py
@@ -27,6 +27,10 @@
)


class AthenaQueryExecutionError(Exception):
"""Exception to be raised when an Athena query fails"""


class AthenaClient(object):
"""A StreamAlert Athena Client for creating tables, databases, and executing queries
@@ -84,17 +88,18 @@ def _execute_and_wait(self, query):
Returns:
str: Athena execution ID for the query that was executed
Raises:
AthenaQueryExecutionError: If any failure occurs during the execution of the
query, this exception will be raised
"""
response = self._execute_query(query)
if not response:
return

exeuction_id = response['QueryExecutionId']

success = self.check_query_status(exeuction_id)
if not success:
LOGGER.error('Athena query failed:\n%s', query)
return
# This will block until the execution is complete, or raise an
# AthenaQueryExecutionError exception if an error occurs
self.check_query_status(exeuction_id)

return exeuction_id

@@ -106,6 +111,10 @@ def _execute_query(self, query):
Returns:
dict: Response object with the status of the running query
Raises:
AthenaQueryExecutionError: If any failure occurs during the execution of the
query, this exception will be raised
"""
LOGGER.debug('Executing query: %s', query)
try:
@@ -114,8 +123,8 @@
QueryExecutionContext={'Database': self.database},
ResultConfiguration={'OutputLocation': self._s3_results_path}
)
except ClientError:
LOGGER.exception('Athena query failed')
except ClientError as err:
raise AthenaQueryExecutionError('Athena query failed:\n{}'.format(err))

def drop_all_tables(self):
"""Drop all table in the database
@@ -165,7 +174,7 @@ def get_table_partitions(self, table_name):

return self._unique_values_from_query(partitions)

def check_query_status(self, query_execution_id):
def check_query_status(self, execution_id):
"""Check in on the running query, back off if the job is running or queued
Args:
@@ -174,8 +183,12 @@
Returns:
bool: True if the query state is SUCCEEDED, False otherwise
Reference https://bit.ly/2uuRtda.
Raises:
AthenaQueryExecutionError: If any failure occurs while checking the status of the
query, this exception will be raised
"""
LOGGER.debug('Checking status of query with execution ID: %s', query_execution_id)
LOGGER.debug('Checking status of query with execution ID: %s', execution_id)

states_to_backoff = {'QUEUED', 'RUNNING'}
@backoff.on_predicate(backoff.fibo,
@@ -190,14 +203,16 @@ def _check_status(query_execution_id):
QueryExecutionId=query_execution_id
)

execution_result = _check_status(query_execution_id)
execution_result = _check_status(execution_id)
state = execution_result['QueryExecution']['Status']['State']
if state != 'SUCCEEDED':
reason = execution_result['QueryExecution']['Status']['StateChangeReason']
LOGGER.error('Query %s %s with reason %s, exiting!', query_execution_id, state, reason)
return False
if state == 'SUCCEEDED':
return

return True
# When the state is not SUCCEEDED, something bad must have occurred, so raise an exception
reason = execution_result['QueryExecution']['Status']['StateChangeReason']
raise AthenaQueryExecutionError(
'Query \'{}\' {} with reason \'{}\', exiting'.format(execution_id, state, reason)
)

def query_result_paginator(self, query):
"""Iterate over all results returned by the Athena query. This is a blocking operation
@@ -207,10 +222,12 @@
Yields:
dict: Response objects with the results of the running query
Raises:
AthenaQueryExecutionError: If any failure occurs during the execution of the
query, this exception will be raised
"""
execution_id = self._execute_and_wait(query)
if not execution_id:
return

paginator = self._client.get_paginator('get_query_results')

@@ -226,6 +243,10 @@ def run_async_query(self, query):
Returns:
dict: Response object with the status of the running query
Raises:
AthenaQueryExecutionError: If any failure occurs during the execution of the
query, this exception will be raised
"""
return self._execute_query(query)

@@ -237,6 +258,10 @@ def run_query(self, query):
Returns:
bool: True if the query ran successfully, False otherwise
Raises:
AthenaQueryExecutionError: If any failure occurs during the execution of the
query, this exception will be raised
"""
return bool(self._execute_and_wait(query))

@@ -248,11 +273,12 @@ def run_query_for_results(self, query):
Returns:
dict: Response object with the result of the running query
Raises:
AthenaQueryExecutionError: If any failure occurs during the execution of the
query, this exception will be raised
"""
execution_id = self._execute_and_wait(query)
if not execution_id:
return

query_results = self._client.get_query_results(QueryExecutionId=execution_id)

# The idea here is to leave the processing logic to the calling functions.
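
With these changes the client raises AthenaQueryExecutionError on any failure instead of returning None or False, so callers handle errors with try/except rather than checking return values. A hedged sketch of caller-side handling (the helper below is illustrative only and assumes an already-constructed AthenaClient instance and a valid query string):

import logging

from stream_alert.shared.athena import AthenaQueryExecutionError

LOGGER = logging.getLogger(__name__)


def fetch_rows(athena_client, query):
    """Collect all result rows for a query, letting Athena failures surface.

    Both athena_client and query are placeholders for this sketch.
    """
    rows = []
    try:
        # Blocks until the query completes; any failure now surfaces as an
        # AthenaQueryExecutionError rather than an empty return value
        for page in athena_client.query_result_paginator(query):
            rows.extend(page['ResultSet']['Rows'])
    except AthenaQueryExecutionError:
        LOGGER.exception('Athena query failed: %s', query)
        raise
    return rows
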
4 changes: 2 additions & 2 deletions stream_alert/shared/rule.py
@@ -69,7 +69,7 @@ def import_folders(*paths):


class RuleCreationError(Exception):
"""Exeception to raise for any errors with invalid rules"""
"""Exception to raise for any errors with invalid rules"""


def rule(**opts):
@@ -242,7 +242,7 @@ def rules_for_log_type(cls, log_type):


class MatcherCreationError(Exception):
"""Exeception to raise for any errors with invalid matchers"""
"""Exception to raise for any errors with invalid matchers"""


def matcher(matcher_func):
27 changes: 14 additions & 13 deletions tests/unit/stream_alert_rule_promotion/test_publisher.py
@@ -20,11 +20,11 @@
import boto3
from mock import Mock, patch, PropertyMock
from moto import mock_ssm
from nose.tools import assert_equal
from nose.tools import assert_equal, assert_raises

from stream_alert.rule_promotion.publisher import StatsPublisher
from stream_alert.rule_promotion.statistic import StagingStatistic
from stream_alert.shared import config
from stream_alert.shared import athena, config


class TestStatsPublisher(object):
@@ -147,24 +147,25 @@ def test_write_state(self, ssm_mock):
self.publisher._write_state()
ssm_mock.put_parameter.assert_called_with(**args)

def test_query_alerts_none(self):
@patch('stream_alert.rule_promotion.publisher.StatsPublisher._publish_message')
@patch('stream_alert.rule_promotion.publisher.StatsPublisher._write_state', Mock())
def test_query_alerts_none(self, publish_mock):
"""StatsPublisher - Query Alerts, No Alerts for Stat"""
stat = list(self._get_fake_stats(count=1))[0]
stat.alert_count = 0
self.publisher._state['send_digest_hour_utc'] = 1
stats = list(self._get_fake_stats(count=1))
stats[0].alert_count = 0
with patch.object(self.publisher, '_athena_client', new_callable=PropertyMock) as mock:
assert_equal(self.publisher._query_alerts(stat), None)
self.publisher.publish(stats)
assert_equal(stats[0].execution_id, None)
mock.run_async_query.assert_not_called()
publish_mock.assert_called_with(stats)

@patch('logging.Logger.error')
def test_query_alerts_bad_reponse(self, log_mock):
def test_query_alerts_bad_reponse(self):
"""StatsPublisher - Query Alerts, Bad Response"""
stat = list(self._get_fake_stats(count=1))[0]
with patch.object(self.publisher, '_athena_client', new_callable=PropertyMock) as mock:
mock.run_async_query.return_value = None
assert_equal(self.publisher._query_alerts(stat), None)
mock.run_async_query.assert_called_once()
log_mock.assert_called_with(
'Failed to query alert info for rule: \'%s\'', 'test_rule_0')
mock.run_async_query.side_effect = athena.AthenaQueryExecutionError()
assert_raises(athena.AthenaQueryExecutionError, self.publisher._query_alerts, stat)

def test_query_alerts(self):
"""StatsPublisher - Query Alerts"""
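
The updated publisher test simulates an Athena failure by setting side_effect on the mocked client and asserting that the exception propagates out of _query_alerts. The same pattern in isolation (a sketch, not from the commit):

from mock import Mock
from nose.tools import assert_raises

from stream_alert.shared.athena import AthenaQueryExecutionError

mock_client = Mock()
# Any call to run_async_query on the mock now raises the new exception
mock_client.run_async_query.side_effect = AthenaQueryExecutionError('query failed')

# Code under test that calls through the mocked client is expected to let
# the exception propagate rather than swallow it and return None
assert_raises(AthenaQueryExecutionError, mock_client.run_async_query, 'SELECT 1')
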
30 changes: 14 additions & 16 deletions tests/unit/stream_alert_shared/test_athena.py
@@ -22,10 +22,12 @@
assert_equal,
assert_false,
assert_items_equal,
assert_true
assert_raises,
assert_true,
raises
)

from stream_alert.shared.athena import AthenaClient
from stream_alert.shared.athena import AthenaClient, AthenaQueryExecutionError
from stream_alert.shared.config import load_config

from tests.unit.helpers.aws_mocks import MockAthenaClient
@@ -122,8 +124,7 @@ def test_get_table_partitions(self):
def test_get_table_partitions_error(self):
"""Athena - Get Table Partitions, Exception"""
self.client._client.raise_exception = True
result = self.client.get_table_partitions('test_table')
assert_equal(result, None)
assert_raises(AthenaQueryExecutionError, self.client.get_table_partitions, 'test_table')

def test_drop_table(self):
"""Athena - Drop Table, Success"""
@@ -132,7 +133,7 @@ def test_drop_table_failure(self):
def test_drop_table_failure(self):
"""Athena - Drop Table, Failure"""
self.client._client.raise_exception = True
assert_false(self.client.drop_table('test_table'))
assert_raises(AthenaQueryExecutionError, self.client.drop_table, 'test_table')

@patch('stream_alert.shared.athena.AthenaClient.drop_table')
def test_drop_all_tables(self, drop_table_mock):
@@ -159,14 +160,12 @@ def test_drop_all_tables_failure(self, drop_table_mock):
def test_drop_all_tables_exception(self):
"""Athena - Drop All Tables, Exception"""
self.client._client.raise_exception = True
assert_false(self.client.drop_all_tables())
assert_raises(AthenaQueryExecutionError, self.client.drop_all_tables)

@patch('logging.Logger.exception')
def test_execute_query(self, log_mock):
def test_execute_query(self):
"""Athena - Execute Query"""
self.client._client.raise_exception = True
self.client._execute_query('BAD SQL')
log_mock.assert_called_with('Athena query failed')
assert_raises(AthenaQueryExecutionError, self.client._execute_query, 'BAD SQL')

def test_execute_and_wait(self):
"""Athena - Execute and Wait"""
@@ -176,13 +175,11 @@ def test_execute_and_wait(self):
result = self.client._execute_and_wait('SQL query')
assert_true(result in self.client._client.query_executions)

@patch('logging.Logger.error')
def test_execute_and_wait_failed(self, log_mock):
def test_execute_and_wait_failed(self):
"""Athena - Execute and Wait, Failed"""
query = 'SQL query'
self.client._client.result_state = 'FAILED'
self.client._execute_and_wait(query)
log_mock.assert_called_with('Athena query failed:\n%s', query)
assert_raises(AthenaQueryExecutionError, self.client._execute_and_wait, query)

def test_query_result_paginator(self):
"""Athena - Query Result Paginator"""
@@ -194,10 +191,11 @@ def test_query_result_paginator(self):
items = list(self.client.query_result_paginator('test query'))
assert_items_equal(items, [{'ResultSet': {'Rows': [data]}}] * 4)

@raises(AthenaQueryExecutionError)
def test_query_result_paginator_error(self):
"""Athena - Query Result Paginator, Exception"""
self.client._client.raise_exception = True
assert_equal(list(self.client.query_result_paginator('test query')), list())
list(self.client.query_result_paginator('test query'))

def test_run_async_query(self):
"""Athena - Run Async Query, Success"""
@@ -206,4 +204,4 @@ def test_run_async_query(self):
def test_run_async_query_failure(self):
"""Athena - Run Async Query, Failure"""
self.client._client.raise_exception = True
assert_false(self.client.run_async_query('test query'))
assert_raises(AthenaQueryExecutionError, self.client.run_async_query, 'test query')
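
These tests now assert on the raised exception, either by passing the callable and its arguments to assert_raises or by decorating the test with @raises. A minimal sketch of both patterns (make_failing_client is a hypothetical stand-in for the MockAthenaClient-based setup used above):

from nose.tools import assert_raises, raises

from stream_alert.shared.athena import AthenaQueryExecutionError


def test_run_query_failure_with_assert_raises():
    client = make_failing_client()  # hypothetical fixture that forces a failure
    # Passes only if the callable raises the expected exception type
    assert_raises(AthenaQueryExecutionError, client.run_query, 'SELECT 1')


@raises(AthenaQueryExecutionError)
def test_run_query_failure_with_decorator():
    # The decorator marks the test as passing once the exception propagates
    make_failing_client().run_query('SELECT 1')
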
2 changes: 1 addition & 1 deletion tests/unit/stream_alert_shared/test_rule.py
@@ -100,7 +100,7 @@ def test_rule(_):

@patch('logging.Logger.exception')
def test_rule_process_exception(self, log_mock):
"""Rule - Process, Exeception"""
"""Rule - Process, Exception"""
# Create a rule function that will raise an exception
def test_rule(_):
raise ValueError('this is a bad rule')
