Skip to content

Commit

Permalink
[AIRFLOW-1774] Allow consistent templating of arguments in MLEngineBa…
Browse files Browse the repository at this point in the history
…tchPredictionOperator

Fix a minor typo and a wrong non-default
assignment

Fix one more typo

Adapt tests to new error messages and fix another
typo

Fix exception type in utils operator test class

Improve cleansing of non-valid training and
prediction job names

Closes apache#2746 from wileeam/ml-engine-prediction-
job-normalization
  • Loading branch information
wileeam authored and Fokko Driesprong committed Apr 11, 2018
1 parent 3475faf commit 65e7025
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 143 deletions.
205 changes: 109 additions & 96 deletions airflow/contrib/operators/mlengine_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,77 +15,17 @@
# limitations under the License.
import re

from airflow import settings
from apiclient import errors

from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook
from airflow.exceptions import AirflowException
from airflow.operators import BaseOperator
from airflow.utils.decorators import apply_defaults
from apiclient import errors

from airflow.utils.log.logging_mixin import LoggingMixin

log = LoggingMixin().log


def _create_prediction_input(project_id,
region,
data_format,
input_paths,
output_path,
model_name=None,
version_name=None,
uri=None,
max_worker_count=None,
runtime_version=None):
"""
Create the batch prediction input from the given parameters.
Args:
A subset of arguments documented in __init__ method of class
MLEngineBatchPredictionOperator
Returns:
A dictionary representing the predictionInput object as documented
in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
Raises:
ValueError: if a unique model/version origin cannot be determined.
"""
prediction_input = {
'dataFormat': data_format,
'inputPaths': input_paths,
'outputPath': output_path,
'region': region
}

if uri:
if model_name or version_name:
log.error(
'Ambiguous model origin: Both uri and model/version name are provided.'
)
raise ValueError('Ambiguous model origin.')
prediction_input['uri'] = uri
elif model_name:
origin_name = 'projects/{}/models/{}'.format(project_id, model_name)
if not version_name:
prediction_input['modelName'] = origin_name
else:
prediction_input['versionName'] = \
origin_name + '/versions/{}'.format(version_name)
else:
log.error(
'Missing model origin: Batch prediction expects a model, '
'a model & version combination, or a URI to savedModel.')
raise ValueError('Missing model origin.')

if max_worker_count:
prediction_input['maxWorkerCount'] = max_worker_count
if runtime_version:
prediction_input['runtimeVersion'] = runtime_version

return prediction_input


def _normalize_mlengine_job_id(job_id):
"""
Replaces invalid MLEngine job_id characters with '_'.
Expand All @@ -99,10 +39,27 @@ def _normalize_mlengine_job_id(job_id):
Returns:
A valid job_id representation.
"""
match = re.search(r'\d', job_id)

# Add a prefix when a job_id starts with a digit or a template
match = re.search(r'\d|\{{2}', job_id)
if match and match.start() is 0:
job_id = 'z_{}'.format(job_id)
return re.sub('[^0-9a-zA-Z]+', '_', job_id)
job = 'z_{}'.format(job_id)
else:
job = job_id

# Clean up 'bad' characters except templates
tracker = 0
cleansed_job_id = ''
for m in re.finditer(r'\{{2}.+?\}{2}', job):
cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_',
job[tracker:m.start()])
cleansed_job_id += job[m.start():m.end()]
tracker = m.end()

# Clean up last substring or the full string if no templates
cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', job[tracker:])

return cleansed_job_id


class MLEngineBatchPredictionOperator(BaseOperator):
Expand Down Expand Up @@ -132,6 +89,8 @@ class MLEngineBatchPredictionOperator(BaseOperator):
if the desired model version is
"projects/my_project/models/my_model/versions/my_version".
See https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs
for further documentation on the parameters.
:param project_id: The Google Cloud project name where the
prediction job is submitted.
Expand Down Expand Up @@ -197,7 +156,14 @@ class MLEngineBatchPredictionOperator(BaseOperator):
"""

template_fields = [
"prediction_job_request",
'_project_id',
'_job_id',
'_region',
'_input_paths',
'_output_path',
'_model_name',
'_version_name',
'_uri',
]

@apply_defaults
Expand All @@ -219,45 +185,91 @@ def __init__(self,
**kwargs):
super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs)

self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self._project_id = project_id
self._job_id = job_id
self._region = region
self._data_format = data_format
self._input_paths = input_paths
self._output_path = output_path
self._model_name = model_name
self._version_name = version_name
self._uri = uri
self._max_worker_count = max_worker_count
self._runtime_version = runtime_version
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to

try:
prediction_input = _create_prediction_input(
project_id, region, data_format, input_paths, output_path,
model_name, version_name, uri, max_worker_count,
runtime_version)
except ValueError as e:
self.log.error(
'Cannot create batch prediction job request due to: %s',
e
)
raise
if not self._project_id:
raise AirflowException('Google Cloud project id is required.')
if not self._job_id:
raise AirflowException(
'An unique job id is required for Google MLEngine prediction '
'job.')

self.prediction_job_request = {
'jobId': _normalize_mlengine_job_id(job_id),
'predictionInput': prediction_input
}
if self._uri:
if self._model_name or self._version_name:
raise AirflowException('Ambiguous model origin: Both uri and '
'model/version name are provided.')

if self._version_name and not self._model_name:
raise AirflowException(
'Missing model: Batch prediction expects '
'a model name when a version name is provided.')

if not (self._uri or self._model_name):
raise AirflowException(
'Missing model origin: Batch prediction expects a model, '
'a model & version combination, or a URI to a savedModel.')

def execute(self, context):
hook = MLEngineHook(self.gcp_conn_id, self.delegate_to)
job_id = _normalize_mlengine_job_id(self._job_id)
prediction_request = {
'jobId': job_id,
'predictionInput': {
'dataFormat': self._data_format,
'inputPaths': self._input_paths,
'outputPath': self._output_path,
'region': self._region
}
}

if self._uri:
prediction_request['predictionInput']['uri'] = self._uri
elif self._model_name:
origin_name = 'projects/{}/models/{}'.format(
self._project_id, self._model_name)
if not self._version_name:
prediction_request['predictionInput'][
'modelName'] = origin_name
else:
prediction_request['predictionInput']['versionName'] = \
origin_name + '/versions/{}'.format(self._version_name)

if self._max_worker_count:
prediction_request['predictionInput'][
'maxWorkerCount'] = self._max_worker_count

if self._runtime_version:
prediction_request['predictionInput'][
'runtimeVersion'] = self._runtime_version

hook = MLEngineHook(self._gcp_conn_id, self._delegate_to)

# Helper method to check if the existing job's prediction input is the
# same as the request we get here.
def check_existing_job(existing_job):
return existing_job.get('predictionInput', None) == \
self.prediction_job_request['predictionInput']
prediction_request['predictionInput']

try:
finished_prediction_job = hook.create_job(
self.project_id,
self.prediction_job_request,
check_existing_job)
self._project_id, prediction_request, check_existing_job)
except errors.HttpError:
raise

if finished_prediction_job['state'] != 'SUCCEEDED':
self.log.error(
'Batch prediction job failed: %s',
str(finished_prediction_job))
self.log.error('MLEngine batch prediction job failed: {}'.format(
str(finished_prediction_job)))
raise RuntimeError(finished_prediction_job['errorMessage'])

return finished_prediction_job['predictionOutput']
Expand Down Expand Up @@ -419,9 +431,8 @@ def execute(self, context):
return hook.create_version(self._project_id, self._model_name,
self._version)
elif self._operation == 'set_default':
return hook.set_default_version(
self._project_id, self._model_name,
self._version['name'])
return hook.set_default_version(self._project_id, self._model_name,
self._version['name'])
elif self._operation == 'list':
return hook.list_versions(self._project_id, self._model_name)
elif self._operation == 'delete':
Expand Down Expand Up @@ -546,7 +557,8 @@ def execute(self, context):

if self._mode == 'DRY_RUN':
self.log.info('In dry_run mode.')
self.log.info('MLEngine Training job request is: {}'.format(training_request))
self.log.info('MLEngine Training job request is: {}'.format(
training_request))
return

hook = MLEngineHook(
Expand All @@ -557,6 +569,7 @@ def execute(self, context):
def check_existing_job(existing_job):
return existing_job.get('trainingInput', None) == \
training_request['trainingInput']

try:
finished_training_job = hook.create_job(
self._project_id, training_request, check_existing_job)
Expand Down
Loading

0 comments on commit 65e7025

Please sign in to comment.