From 65e7025f3af378dfa825eb04d005ec1f7a422cde Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Guillermo=20Rodr=C3=ADguez=20Cano?=
Date: Wed, 11 Apr 2018 11:57:21 +0200
Subject: [PATCH] [AIRFLOW-1774] Allow consistent templating of arguments in MLEngineBatchPredictionOperator

Fix a minor typo and a wrong non-default assignment
Fix one more typo
Adapt tests to new error messages and fix another typo
Fix exception type in utils operator test class
Improve cleansing of non-valid training and prediction job names

Closes #2746 from wileeam/ml-engine-prediction-job-normalization
---
 .../contrib/operators/mlengine_operator.py     | 205 ++++++++++--------
 .../operators/test_mlengine_operator.py        |  88 ++++----
 .../operators/test_mlengine_operator_utils.py  |   6 +-
 3 files changed, 156 insertions(+), 143 deletions(-)

diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index 0d033d35178f9..3dd63f2c116bc 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -15,77 +15,17 @@
 # limitations under the License.
 
 import re
 
-from airflow import settings
+from apiclient import errors
+
 from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook
 from airflow.exceptions import AirflowException
 from airflow.operators import BaseOperator
 from airflow.utils.decorators import apply_defaults
-
-from apiclient import errors
-
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 log = LoggingMixin().log
 
 
-def _create_prediction_input(project_id,
-                             region,
-                             data_format,
-                             input_paths,
-                             output_path,
-                             model_name=None,
-                             version_name=None,
-                             uri=None,
-                             max_worker_count=None,
-                             runtime_version=None):
-    """
-    Create the batch prediction input from the given parameters.
-
-    Args:
-        A subset of arguments documented in __init__ method of class
-        MLEngineBatchPredictionOperator
-
-    Returns:
-        A dictionary representing the predictionInput object as documented
-        in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
-
-    Raises:
-        ValueError: if a unique model/version origin cannot be determined.
-    """
-    prediction_input = {
-        'dataFormat': data_format,
-        'inputPaths': input_paths,
-        'outputPath': output_path,
-        'region': region
-    }
-
-    if uri:
-        if model_name or version_name:
-            log.error(
-                'Ambiguous model origin: Both uri and model/version name are provided.'
-            )
-            raise ValueError('Ambiguous model origin.')
-        prediction_input['uri'] = uri
-    elif model_name:
-        origin_name = 'projects/{}/models/{}'.format(project_id, model_name)
-        if not version_name:
-            prediction_input['modelName'] = origin_name
-        else:
-            prediction_input['versionName'] = \
-                origin_name + '/versions/{}'.format(version_name)
-    else:
-        log.error(
-            'Missing model origin: Batch prediction expects a model, '
-            'a model & version combination, or a URI to savedModel.')
-        raise ValueError('Missing model origin.')
-
-    if max_worker_count:
-        prediction_input['maxWorkerCount'] = max_worker_count
-    if runtime_version:
-        prediction_input['runtimeVersion'] = runtime_version
-
-    return prediction_input
-
-
 def _normalize_mlengine_job_id(job_id):
     """
     Replaces invalid MLEngine job_id characters with '_'.
@@ -99,10 +39,27 @@ def _normalize_mlengine_job_id(job_id):
 
     Returns:
         A valid job_id representation.
""" - match = re.search(r'\d', job_id) + + # Add a prefix when a job_id starts with a digit or a template + match = re.search(r'\d|\{{2}', job_id) if match and match.start() is 0: - job_id = 'z_{}'.format(job_id) - return re.sub('[^0-9a-zA-Z]+', '_', job_id) + job = 'z_{}'.format(job_id) + else: + job = job_id + + # Clean up 'bad' characters except templates + tracker = 0 + cleansed_job_id = '' + for m in re.finditer(r'\{{2}.+?\}{2}', job): + cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', + job[tracker:m.start()]) + cleansed_job_id += job[m.start():m.end()] + tracker = m.end() + + # Clean up last substring or the full string if no templates + cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', job[tracker:]) + + return cleansed_job_id class MLEngineBatchPredictionOperator(BaseOperator): @@ -132,6 +89,8 @@ class MLEngineBatchPredictionOperator(BaseOperator): if the desired model version is "projects/my_project/models/my_model/versions/my_version". + See https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs + for further documentation on the parameters. :param project_id: The Google Cloud project name where the prediction job is submitted. @@ -197,7 +156,14 @@ class MLEngineBatchPredictionOperator(BaseOperator): """ template_fields = [ - "prediction_job_request", + '_project_id', + '_job_id', + '_region', + '_input_paths', + '_output_path', + '_model_name', + '_version_name', + '_uri', ] @apply_defaults @@ -219,45 +185,91 @@ def __init__(self, **kwargs): super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs) - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to + self._project_id = project_id + self._job_id = job_id + self._region = region + self._data_format = data_format + self._input_paths = input_paths + self._output_path = output_path + self._model_name = model_name + self._version_name = version_name + self._uri = uri + self._max_worker_count = max_worker_count + self._runtime_version = runtime_version + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to - try: - prediction_input = _create_prediction_input( - project_id, region, data_format, input_paths, output_path, - model_name, version_name, uri, max_worker_count, - runtime_version) - except ValueError as e: - self.log.error( - 'Cannot create batch prediction job request due to: %s', - e - ) - raise + if not self._project_id: + raise AirflowException('Google Cloud project id is required.') + if not self._job_id: + raise AirflowException( + 'An unique job id is required for Google MLEngine prediction ' + 'job.') - self.prediction_job_request = { - 'jobId': _normalize_mlengine_job_id(job_id), - 'predictionInput': prediction_input - } + if self._uri: + if self._model_name or self._version_name: + raise AirflowException('Ambiguous model origin: Both uri and ' + 'model/version name are provided.') + + if self._version_name and not self._model_name: + raise AirflowException( + 'Missing model: Batch prediction expects ' + 'a model name when a version name is provided.') + + if not (self._uri or self._model_name): + raise AirflowException( + 'Missing model origin: Batch prediction expects a model, ' + 'a model & version combination, or a URI to a savedModel.') def execute(self, context): - hook = MLEngineHook(self.gcp_conn_id, self.delegate_to) + job_id = _normalize_mlengine_job_id(self._job_id) + prediction_request = { + 'jobId': job_id, + 'predictionInput': { + 'dataFormat': self._data_format, + 'inputPaths': self._input_paths, + 'outputPath': 
self._output_path, + 'region': self._region + } + } + if self._uri: + prediction_request['predictionInput']['uri'] = self._uri + elif self._model_name: + origin_name = 'projects/{}/models/{}'.format( + self._project_id, self._model_name) + if not self._version_name: + prediction_request['predictionInput'][ + 'modelName'] = origin_name + else: + prediction_request['predictionInput']['versionName'] = \ + origin_name + '/versions/{}'.format(self._version_name) + + if self._max_worker_count: + prediction_request['predictionInput'][ + 'maxWorkerCount'] = self._max_worker_count + + if self._runtime_version: + prediction_request['predictionInput'][ + 'runtimeVersion'] = self._runtime_version + + hook = MLEngineHook(self._gcp_conn_id, self._delegate_to) + + # Helper method to check if the existing job's prediction input is the + # same as the request we get here. def check_existing_job(existing_job): return existing_job.get('predictionInput', None) == \ - self.prediction_job_request['predictionInput'] + prediction_request['predictionInput'] + try: finished_prediction_job = hook.create_job( - self.project_id, - self.prediction_job_request, - check_existing_job) + self._project_id, prediction_request, check_existing_job) except errors.HttpError: raise if finished_prediction_job['state'] != 'SUCCEEDED': - self.log.error( - 'Batch prediction job failed: %s', - str(finished_prediction_job)) + self.log.error('MLEngine batch prediction job failed: {}'.format( + str(finished_prediction_job))) raise RuntimeError(finished_prediction_job['errorMessage']) return finished_prediction_job['predictionOutput'] @@ -419,9 +431,8 @@ def execute(self, context): return hook.create_version(self._project_id, self._model_name, self._version) elif self._operation == 'set_default': - return hook.set_default_version( - self._project_id, self._model_name, - self._version['name']) + return hook.set_default_version(self._project_id, self._model_name, + self._version['name']) elif self._operation == 'list': return hook.list_versions(self._project_id, self._model_name) elif self._operation == 'delete': @@ -546,7 +557,8 @@ def execute(self, context): if self._mode == 'DRY_RUN': self.log.info('In dry_run mode.') - self.log.info('MLEngine Training job request is: {}'.format(training_request)) + self.log.info('MLEngine Training job request is: {}'.format( + training_request)) return hook = MLEngineHook( @@ -557,6 +569,7 @@ def execute(self, context): def check_existing_job(existing_job): return existing_job.get('trainingInput', None) == \ training_request['trainingInput'] + try: finished_training_job = hook.create_job( self._project_id, training_request, check_existing_job) diff --git a/tests/contrib/operators/test_mlengine_operator.py b/tests/contrib/operators/test_mlengine_operator.py index 75b46a0c44d29..2766e5d767658 100644 --- a/tests/contrib/operators/test_mlengine_operator.py +++ b/tests/contrib/operators/test_mlengine_operator.py @@ -15,21 +15,19 @@ # specific language governing permissions and limitations # under the License. 
-from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import datetime -from apiclient import errors -import httplib2 import unittest -from airflow import configuration, DAG -from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator -from airflow.contrib.operators.mlengine_operator import MLEngineTrainingOperator +import httplib2 +from apiclient import errors +from mock import ANY, patch -from mock import ANY -from mock import patch +from airflow import DAG, configuration +from airflow.contrib.operators.mlengine_operator import (MLEngineBatchPredictionOperator, + MLEngineTrainingOperator) +from airflow.exceptions import AirflowException DEFAULT_DATE = datetime.datetime(2017, 6, 6) @@ -58,7 +56,7 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase): 'data_format': 'TEXT', 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'], 'output_path': - 'gs://12_legal_bucket_underscore_number/legal-output-path', + 'gs://12_legal_bucket_underscore_number/legal-output-path', 'task_id': 'test-prediction' } @@ -105,14 +103,12 @@ def testSuccessWithModel(self): mock_hook.assert_called_with('google_cloud_default', None) hook_instance.create_job.assert_called_once_with( - 'test-project', - { + 'test-project', { 'jobId': 'test_prediction', 'predictionInput': input_with_model }, ANY) - self.assertEquals( - success_message['predictionOutput'], - prediction_output) + self.assertEquals(success_message['predictionOutput'], + prediction_output) def testSuccessWithVersion(self): with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \ @@ -132,7 +128,8 @@ def testSuccessWithVersion(self): hook_instance.create_job.return_value = success_message prediction_task = MLEngineBatchPredictionOperator( - job_id='test_prediction', project_id='test-project', + job_id='test_prediction', + project_id='test-project', region=input_with_version['region'], data_format=input_with_version['dataFormat'], input_paths=input_with_version['inputPaths'], @@ -145,14 +142,12 @@ def testSuccessWithVersion(self): mock_hook.assert_called_with('google_cloud_default', None) hook_instance.create_job.assert_called_with( - 'test-project', - { + 'test-project', { 'jobId': 'test_prediction', 'predictionInput': input_with_version }, ANY) - self.assertEquals( - success_message['predictionOutput'], - prediction_output) + self.assertEquals(success_message['predictionOutput'], + prediction_output) def testSuccessWithURI(self): with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \ @@ -184,48 +179,51 @@ def testSuccessWithURI(self): mock_hook.assert_called_with('google_cloud_default', None) hook_instance.create_job.assert_called_with( - 'test-project', - { + 'test-project', { 'jobId': 'test_prediction', 'predictionInput': input_with_uri }, ANY) - self.assertEquals( - success_message['predictionOutput'], - prediction_output) + self.assertEquals(success_message['predictionOutput'], + prediction_output) def testInvalidModelOrigin(self): # Test that both uri and model is given task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['uri'] = 'gs://fake-uri/saved_model' task_args['model_name'] = 'fake_model' - with self.assertRaises(ValueError) as context: + with self.assertRaises(AirflowException) as context: MLEngineBatchPredictionOperator(**task_args).execute(None) - self.assertEquals('Ambiguous model origin.', str(context.exception)) + 
self.assertEquals('Ambiguous model origin: Both uri and ' + 'model/version name are provided.', + str(context.exception)) # Test that both uri and model/version is given task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['uri'] = 'gs://fake-uri/saved_model' task_args['model_name'] = 'fake_model' task_args['version_name'] = 'fake_version' - with self.assertRaises(ValueError) as context: + with self.assertRaises(AirflowException) as context: MLEngineBatchPredictionOperator(**task_args).execute(None) - self.assertEquals('Ambiguous model origin.', str(context.exception)) + self.assertEquals('Ambiguous model origin: Both uri and ' + 'model/version name are provided.', + str(context.exception)) # Test that a version is given without a model task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['version_name'] = 'bare_version' - with self.assertRaises(ValueError) as context: + with self.assertRaises(AirflowException) as context: MLEngineBatchPredictionOperator(**task_args).execute(None) - self.assertEquals( - 'Missing model origin.', - str(context.exception)) + self.assertEquals('Missing model: Batch prediction expects a model ' + 'name when a version name is provided.', + str(context.exception)) # Test that none of uri, model, model/version is given task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() - with self.assertRaises(ValueError) as context: + with self.assertRaises(AirflowException) as context: MLEngineBatchPredictionOperator(**task_args).execute(None) self.assertEquals( - 'Missing model origin.', + 'Missing model origin: Batch prediction expects a ' + 'model, a model & version combination, or a URI to a savedModel.', str(context.exception)) def testHttpError(self): @@ -241,7 +239,8 @@ def testHttpError(self): hook_instance.create_job.side_effect = errors.HttpError( resp=httplib2.Response({ 'status': http_error_code - }), content=b'Forbidden') + }), + content=b'Forbidden') with self.assertRaises(errors.HttpError) as context: prediction_task = MLEngineBatchPredictionOperator( @@ -258,8 +257,7 @@ def testHttpError(self): mock_hook.assert_called_with('google_cloud_default', None) hook_instance.create_job.assert_called_with( - 'test-project', - { + 'test-project', { 'jobId': 'test_prediction', 'predictionInput': input_with_model }, ANY) @@ -313,11 +311,12 @@ def testSuccessCreateTrainingJob(self): hook_instance = mock_hook.return_value hook_instance.create_job.return_value = success_response - training_op = MLEngineTrainingOperator(**self.TRAINING_DEFAULT_ARGS) + training_op = MLEngineTrainingOperator( + **self.TRAINING_DEFAULT_ARGS) training_op.execute(None) - mock_hook.assert_called_with(gcp_conn_id='google_cloud_default', - delegate_to=None) + mock_hook.assert_called_with( + gcp_conn_id='google_cloud_default', delegate_to=None) # Make sure only 'create_job' is invoked on hook instance self.assertEquals(len(hook_instance.mock_calls), 1) hook_instance.create_job.assert_called_with( @@ -331,7 +330,8 @@ def testHttpError(self): hook_instance.create_job.side_effect = errors.HttpError( resp=httplib2.Response({ 'status': http_error_code - }), content=b'Forbidden') + }), + content=b'Forbidden') with self.assertRaises(errors.HttpError) as context: training_op = MLEngineTrainingOperator( diff --git a/tests/contrib/operators/test_mlengine_operator_utils.py b/tests/contrib/operators/test_mlengine_operator_utils.py index c8f6fb5544f64..0cb106da6bb8d 100644 --- a/tests/contrib/operators/test_mlengine_operator_utils.py +++ 
b/tests/contrib/operators/test_mlengine_operator_utils.py @@ -158,14 +158,14 @@ def testFailures(self): 'dag': dag, } - with self.assertRaisesRegexp(ValueError, 'Missing model origin'): + with self.assertRaisesRegexp(AirflowException, 'Missing model origin'): _ = create_evaluate_ops(**other_params_but_models) - with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'): + with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'): _ = create_evaluate_ops(model_uri='abc', model_name='cde', **other_params_but_models) - with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'): + with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'): _ = create_evaluate_ops(model_uri='abc', version_name='vvv', **other_params_but_models)
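
A quick illustration of the job-id cleansing introduced above. The example values are hand-worked for illustration and are not taken from the test suite; they assume the helper stays importable from the operator module:

    from airflow.contrib.operators.mlengine_operator import _normalize_mlengine_job_id

    _normalize_mlengine_job_id('batch.prediction-2018')
    # -> 'batch_prediction_2018'    (invalid characters collapsed to '_')
    _normalize_mlengine_job_id('2018_batch.prediction')
    # -> 'z_2018_batch_prediction'  (prefixed because it starts with a digit)
    _normalize_mlengine_job_id('prediction_{{ ds_nodash }}')
    # -> 'prediction_{{ ds_nodash }}'  (Jinja templates are passed through untouched)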
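
A minimal usage sketch of the now-templatable arguments. The project, bucket, and model names below are placeholders, not values taken from this patch:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator

    dag = DAG('mlengine_batch_prediction_example',
              start_date=datetime(2018, 4, 11),
              schedule_interval='@daily')

    predict = MLEngineBatchPredictionOperator(
        task_id='batch_prediction',
        project_id='my-gcp-project',                             # placeholder project
        job_id='prediction_{{ ds_nodash }}',                     # templated, then cleansed in execute()
        region='us-central1',
        data_format='TEXT',
        input_paths=['gs://my-bucket/input/{{ ds_nodash }}/*'],  # templated input path
        output_path='gs://my-bucket/output/{{ ds_nodash }}',     # templated output path
        model_name='my_model',                                   # placeholder model
        dag=dag)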
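
With those arguments, execute() assembles roughly the following request and hands it to MLEngineHook.create_job(); the jobId shown is the rendered and cleansed value for an execution date of 2018-04-11:

    {
        'jobId': 'prediction_20180411',
        'predictionInput': {
            'dataFormat': 'TEXT',
            'inputPaths': ['gs://my-bucket/input/20180411/*'],
            'outputPath': 'gs://my-bucket/output/20180411',
            'region': 'us-central1',
            'modelName': 'projects/my-gcp-project/models/my_model'
        }
    }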