From d1b6b7a3c5b4aa720cb23c6155c1a553395e2640 Mon Sep 17 00:00:00 2001 From: ryandeivert Date: Mon, 2 Jul 2018 11:24:04 -0700 Subject: [PATCH] updates to raise appropriate errors from athena client, etc (#779) --- stream_alert/rule_promotion/promoter.py | 2 + stream_alert/rule_promotion/publisher.py | 11 ++- stream_alert/rule_promotion/statistic.py | 10 ++- stream_alert/shared/athena.py | 68 +++++++++++++------ stream_alert/shared/rule.py | 4 +- .../test_publisher.py | 27 ++++---- tests/unit/stream_alert_shared/test_athena.py | 30 ++++---- tests/unit/stream_alert_shared/test_rule.py | 2 +- 8 files changed, 93 insertions(+), 61 deletions(-) diff --git a/stream_alert/rule_promotion/promoter.py b/stream_alert/rule_promotion/promoter.py index b98b408c5..4820af3bc 100644 --- a/stream_alert/rule_promotion/promoter.py +++ b/stream_alert/rule_promotion/promoter.py @@ -104,6 +104,8 @@ def _update_alert_count(self): row_values = [data.values()[0] for data in row['Data']] rule_name, alert_count = row_values[0], int(row_values[1]) + LOGGER.debug('Found %d alerts for rule \'%s\'', alert_count, rule_name) + self._staging_stats[rule_name].alert_count = alert_count def run(self): diff --git a/stream_alert/rule_promotion/publisher.py b/stream_alert/rule_promotion/publisher.py index 739e2c23b..8d07c29ae 100644 --- a/stream_alert/rule_promotion/publisher.py +++ b/stream_alert/rule_promotion/publisher.py @@ -103,17 +103,10 @@ def _query_alerts(self, stat): Returns: str: Execution ID for running Athena query """ - # If there are no alerts, do not run the comprehensive query - if not stat.alert_count: - return - info_statement = stat.sql_info_statement LOGGER.debug('Querying alert info for rule \'%s\': %s', stat.rule_name, info_statement) response = self._athena_client.run_async_query(info_statement) - if not response: - LOGGER.error('Failed to query alert info for rule: \'%s\'', stat.rule_name) - return return response['QueryExecutionId'] @@ -156,6 +149,10 @@ def publish(self, stats): if self._should_send_digest: for stat in stats: + # If there are no alerts, do not run the comprehensive query + if not stat: + continue + stat.execution_id = self._query_alerts(stat) self._publish_message(stats) diff --git a/stream_alert/rule_promotion/statistic.py b/stream_alert/rule_promotion/statistic.py index f0d40169b..824452732 100644 --- a/stream_alert/rule_promotion/statistic.py +++ b/stream_alert/rule_promotion/statistic.py @@ -18,6 +18,8 @@ class StagingStatistic(object): """Store information on generated alerts.""" + _ALERT_COUNT_UNKOWN = 'unknown' + _COUNT_QUERY_TEMPLATE = ("SELECT '{rule_name}' AS rule_name, count(*) AS count " "FROM alerts WHERE dt >= '{date}-{hour:02}' AND " "rule_name = '{rule_name}'") @@ -34,9 +36,15 @@ def __init__(self, staged_at, staged_until, current_time, rule): self._current_time = current_time self._staged_at = staged_at self.staged_until = staged_until - self.alert_count = 'unknown' + self.alert_count = self._ALERT_COUNT_UNKOWN self.execution_id = None + def __nonzero__(self): + return self.alert_count not in {0, self._ALERT_COUNT_UNKOWN} + + # For forward compatibility to Python3 + __bool__ = __nonzero__ + def __lt__(self, other): """Statistic should be ordered by their alert count.""" return self.alert_count < other.alert_count diff --git a/stream_alert/shared/athena.py b/stream_alert/shared/athena.py index 27185419f..44f3a0f81 100644 --- a/stream_alert/shared/athena.py +++ b/stream_alert/shared/athena.py @@ -27,6 +27,10 @@ ) +class AthenaQueryExecutionError(Exception): + """Exception 
to be raised when an Athena query fails""" + + class AthenaClient(object): """A StreamAlert Athena Client for creating tables, databases, and executing queries @@ -84,17 +88,18 @@ def _execute_and_wait(self, query): Returns: str: Athena execution ID for the query that was executed + + Raises: + AthenaQueryExecutionError: If any failure occurs during the execution of the + query, this exception will be raised """ response = self._execute_query(query) - if not response: - return exeuction_id = response['QueryExecutionId'] - success = self.check_query_status(exeuction_id) - if not success: - LOGGER.error('Athena query failed:\n%s', query) - return + # This will block until the execution is complete, or raise an + # AthenaQueryExecutionError exception if an error occurs + self.check_query_status(exeuction_id) return exeuction_id @@ -106,6 +111,10 @@ def _execute_query(self, query): Returns: dict: Response object with the status of the running query + + Raises: + AthenaQueryExecutionError: If any failure occurs during the execution of the + query, this exception will be raised """ LOGGER.debug('Executing query: %s', query) try: @@ -114,8 +123,8 @@ def _execute_query(self, query): QueryExecutionContext={'Database': self.database}, ResultConfiguration={'OutputLocation': self._s3_results_path} ) - except ClientError: - LOGGER.exception('Athena query failed') + except ClientError as err: + raise AthenaQueryExecutionError('Athena query failed:\n{}'.format(err)) def drop_all_tables(self): """Drop all table in the database @@ -165,7 +174,7 @@ def get_table_partitions(self, table_name): return self._unique_values_from_query(partitions) - def check_query_status(self, query_execution_id): + def check_query_status(self, execution_id): """Check in on the running query, back off if the job is running or queued Args: @@ -174,8 +183,12 @@ def check_query_status(self, query_execution_id): Returns: bool: True if the query state is SUCCEEDED, False otherwise Reference https://bit.ly/2uuRtda. + + Raises: + AthenaQueryExecutionError: If any failure occurs while checking the status of the + query, this exception will be raised """ - LOGGER.debug('Checking status of query with execution ID: %s', query_execution_id) + LOGGER.debug('Checking status of query with execution ID: %s', execution_id) states_to_backoff = {'QUEUED', 'RUNNING'} @backoff.on_predicate(backoff.fibo, @@ -190,14 +203,16 @@ def _check_status(query_execution_id): QueryExecutionId=query_execution_id ) - execution_result = _check_status(query_execution_id) + execution_result = _check_status(execution_id) state = execution_result['QueryExecution']['Status']['State'] - if state != 'SUCCEEDED': - reason = execution_result['QueryExecution']['Status']['StateChangeReason'] - LOGGER.error('Query %s %s with reason %s, exiting!', query_execution_id, state, reason) - return False + if state == 'SUCCEEDED': + return - return True + # When the state is not SUCCEEDED, something bad must have occurred, so raise an exception + reason = execution_result['QueryExecution']['Status']['StateChangeReason'] + raise AthenaQueryExecutionError( + 'Query \'{}\' {} with reason \'{}\', exiting'.format(execution_id, state, reason) + ) def query_result_paginator(self, query): """Iterate over all results returned by the Athena query. 
This is a blocking operation @@ -207,10 +222,12 @@ def query_result_paginator(self, query): Yields: dict: Response objects with the results of the running query + + Raises: + AthenaQueryExecutionError: If any failure occurs during the execution of the + query, this exception will be raised """ execution_id = self._execute_and_wait(query) - if not execution_id: - return paginator = self._client.get_paginator('get_query_results') @@ -226,6 +243,10 @@ def run_async_query(self, query): Returns: dict: Response object with the status of the running query + + Raises: + AthenaQueryExecutionError: If any failure occurs during the execution of the + query, this exception will be raised """ return self._execute_query(query) @@ -237,6 +258,10 @@ def run_query(self, query): Returns: bool: True if the query ran successfully, False otherwise + + Raises: + AthenaQueryExecutionError: If any failure occurs during the execution of the + query, this exception will be raised """ return bool(self._execute_and_wait(query)) @@ -248,11 +273,12 @@ def run_query_for_results(self, query): Returns: dict: Response object with the result of the running query + + Raises: + AthenaQueryExecutionError: If any failure occurs during the execution of the + query, this exception will be raised """ execution_id = self._execute_and_wait(query) - if not execution_id: - return - query_results = self._client.get_query_results(QueryExecutionId=execution_id) # The idea here is to leave the processing logic to the calling functions. diff --git a/stream_alert/shared/rule.py b/stream_alert/shared/rule.py index 4c1a51b04..4a8d981d0 100644 --- a/stream_alert/shared/rule.py +++ b/stream_alert/shared/rule.py @@ -69,7 +69,7 @@ def import_folders(*paths): class RuleCreationError(Exception): - """Exeception to raise for any errors with invalid rules""" + """Exception to raise for any errors with invalid rules""" def rule(**opts): @@ -242,7 +242,7 @@ def rules_for_log_type(cls, log_type): class MatcherCreationError(Exception): - """Exeception to raise for any errors with invalid matchers""" + """Exception to raise for any errors with invalid matchers""" def matcher(matcher_func): diff --git a/tests/unit/stream_alert_rule_promotion/test_publisher.py b/tests/unit/stream_alert_rule_promotion/test_publisher.py index 676e01c42..74db9464f 100644 --- a/tests/unit/stream_alert_rule_promotion/test_publisher.py +++ b/tests/unit/stream_alert_rule_promotion/test_publisher.py @@ -20,11 +20,11 @@ import boto3 from mock import Mock, patch, PropertyMock from moto import mock_ssm -from nose.tools import assert_equal +from nose.tools import assert_equal, assert_raises from stream_alert.rule_promotion.publisher import StatsPublisher from stream_alert.rule_promotion.statistic import StagingStatistic -from stream_alert.shared import config +from stream_alert.shared import athena, config class TestStatsPublisher(object): @@ -147,24 +147,25 @@ def test_write_state(self, ssm_mock): self.publisher._write_state() ssm_mock.put_parameter.assert_called_with(**args) - def test_query_alerts_none(self): + @patch('stream_alert.rule_promotion.publisher.StatsPublisher._publish_message') + @patch('stream_alert.rule_promotion.publisher.StatsPublisher._write_state', Mock()) + def test_query_alerts_none(self, publish_mock): """StatsPublisher - Query Alerts, No Alerts for Stat""" - stat = list(self._get_fake_stats(count=1))[0] - stat.alert_count = 0 + self.publisher._state['send_digest_hour_utc'] = 1 + stats = list(self._get_fake_stats(count=1)) + stats[0].alert_count = 0 with 
patch.object(self.publisher, '_athena_client', new_callable=PropertyMock) as mock: - assert_equal(self.publisher._query_alerts(stat), None) + self.publisher.publish(stats) + assert_equal(stats[0].execution_id, None) mock.run_async_query.assert_not_called() + publish_mock.assert_called_with(stats) - @patch('logging.Logger.error') - def test_query_alerts_bad_reponse(self, log_mock): + def test_query_alerts_bad_reponse(self): """StatsPublisher - Query Alerts, Bad Response""" stat = list(self._get_fake_stats(count=1))[0] with patch.object(self.publisher, '_athena_client', new_callable=PropertyMock) as mock: - mock.run_async_query.return_value = None - assert_equal(self.publisher._query_alerts(stat), None) - mock.run_async_query.assert_called_once() - log_mock.assert_called_with( - 'Failed to query alert info for rule: \'%s\'', 'test_rule_0') + mock.run_async_query.side_effect = athena.AthenaQueryExecutionError() + assert_raises(athena.AthenaQueryExecutionError, self.publisher._query_alerts, stat) def test_query_alerts(self): """StatsPublisher - Query Alerts""" diff --git a/tests/unit/stream_alert_shared/test_athena.py b/tests/unit/stream_alert_shared/test_athena.py index 65e98963c..e9f4ddad1 100644 --- a/tests/unit/stream_alert_shared/test_athena.py +++ b/tests/unit/stream_alert_shared/test_athena.py @@ -22,10 +22,12 @@ assert_equal, assert_false, assert_items_equal, - assert_true + assert_raises, + assert_true, + raises ) -from stream_alert.shared.athena import AthenaClient +from stream_alert.shared.athena import AthenaClient, AthenaQueryExecutionError from stream_alert.shared.config import load_config from tests.unit.helpers.aws_mocks import MockAthenaClient @@ -122,8 +124,7 @@ def test_get_table_partitions(self): def test_get_table_partitions_error(self): """Athena - Get Table Partitions, Exception""" self.client._client.raise_exception = True - result = self.client.get_table_partitions('test_table') - assert_equal(result, None) + assert_raises(AthenaQueryExecutionError, self.client.get_table_partitions, 'test_table') def test_drop_table(self): """Athena - Drop Table, Success""" @@ -132,7 +133,7 @@ def test_drop_table(self): def test_drop_table_failure(self): """Athena - Drop Table, Failure""" self.client._client.raise_exception = True - assert_false(self.client.drop_table('test_table')) + assert_raises(AthenaQueryExecutionError, self.client.drop_table, 'test_table') @patch('stream_alert.shared.athena.AthenaClient.drop_table') def test_drop_all_tables(self, drop_table_mock): @@ -159,14 +160,12 @@ def test_drop_all_tables_failure(self, drop_table_mock): def test_drop_all_tables_exception(self): """Athena - Drop All Tables, Exception""" self.client._client.raise_exception = True - assert_false(self.client.drop_all_tables()) + assert_raises(AthenaQueryExecutionError, self.client.drop_all_tables) - @patch('logging.Logger.exception') - def test_execute_query(self, log_mock): + def test_execute_query(self): """Athena - Execute Query""" self.client._client.raise_exception = True - self.client._execute_query('BAD SQL') - log_mock.assert_called_with('Athena query failed') + assert_raises(AthenaQueryExecutionError, self.client._execute_query, 'BAD SQL') def test_execute_and_wait(self): """Athena - Execute and Wait""" @@ -176,13 +175,11 @@ def test_execute_and_wait(self): result = self.client._execute_and_wait('SQL query') assert_true(result in self.client._client.query_executions) - @patch('logging.Logger.error') - def test_execute_and_wait_failed(self, log_mock): + def 
test_execute_and_wait_failed(self): """Athena - Execute and Wait, Failed""" query = 'SQL query' self.client._client.result_state = 'FAILED' - self.client._execute_and_wait(query) - log_mock.assert_called_with('Athena query failed:\n%s', query) + assert_raises(AthenaQueryExecutionError, self.client._execute_and_wait, query) def test_query_result_paginator(self): """Athena - Query Result Paginator""" @@ -194,10 +191,11 @@ def test_query_result_paginator(self): items = list(self.client.query_result_paginator('test query')) assert_items_equal(items, [{'ResultSet': {'Rows': [data]}}] * 4) + @raises(AthenaQueryExecutionError) def test_query_result_paginator_error(self): """Athena - Query Result Paginator, Exception""" self.client._client.raise_exception = True - assert_equal(list(self.client.query_result_paginator('test query')), list()) + list(self.client.query_result_paginator('test query')) def test_run_async_query(self): """Athena - Run Async Query, Success""" @@ -206,4 +204,4 @@ def test_run_async_query(self): def test_run_async_query_failure(self): """Athena - Run Async Query, Failure""" self.client._client.raise_exception = True - assert_false(self.client.run_async_query('test query')) + assert_raises(AthenaQueryExecutionError, self.client.run_async_query, 'test query') diff --git a/tests/unit/stream_alert_shared/test_rule.py b/tests/unit/stream_alert_shared/test_rule.py index f0fee7e78..9dee4fc25 100644 --- a/tests/unit/stream_alert_shared/test_rule.py +++ b/tests/unit/stream_alert_shared/test_rule.py @@ -100,7 +100,7 @@ def test_rule(_): @patch('logging.Logger.exception') def test_rule_process_exception(self, log_mock): - """Rule - Process, Exeception""" + """Rule - Process, Exception""" # Create a rule function that will raise an exception def test_rule(_): raise ValueError('this is a bad rule')
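
For illustration only: a minimal sketch of how calling code adapts to the exception-based error handling this patch introduces, where AthenaClient methods raise AthenaQueryExecutionError instead of returning None/False on failure. The helper name and query string below are assumptions for the example, not taken from the patch; only run_query_for_results() and AthenaQueryExecutionError come from the diff above.

    from stream_alert.shared.athena import AthenaQueryExecutionError

    def fetch_alert_rows(athena_client, query):
        """Illustrative helper showing the new exception-based calling pattern."""
        try:
            # run_query_for_results() now raises AthenaQueryExecutionError on any
            # query failure instead of returning None, so no falsy-response check
            # is needed here.
            return athena_client.run_query_for_results(query)
        except AthenaQueryExecutionError:
            # Callers decide how to recover; previously they had to test for a
            # falsy return value and consult the logs for the failure reason.
            return None

The same pattern applies to run_query(), run_async_query(), and query_result_paginator(): wrap the call in try/except rather than branching on the return value.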