Skip to content

Commit

Permalink
Add masterConfig parameter to MLEngineStartTrainingJobOperator (apach…
Browse files Browse the repository at this point in the history
…e#10578)

Co-authored-by: antonio-davide-cali <[email protected]>
  • Loading branch information
antoniocali and antonio-davide-cali authored Sep 4, 2020
1 parent e4de728 commit 6e3d7b6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
11 changes: 11 additions & 0 deletions airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,9 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
:param master_type: Cloud ML Engine machine name.
Must be set when scale_tier is CUSTOM. (templated)
:type master_type: str
:param master_config: Cloud ML Engine master config.
master_type must be set if master_config is provided. (templated)
:type master_config: dict
:param runtime_version: The Google Cloud ML runtime version to use for
training. (templated)
:type runtime_version: str
Expand Down Expand Up @@ -1147,6 +1150,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
'_region',
'_scale_tier',
'_master_type',
'_master_config',
'_runtime_version',
'_python_version',
'_job_dir',
Expand All @@ -1166,6 +1170,7 @@ def __init__(
region: str,
scale_tier: Optional[str] = None,
master_type: Optional[str] = None,
master_config: Optional[Dict] = None,
runtime_version: Optional[str] = None,
python_version: Optional[str] = None,
job_dir: Optional[str] = None,
Expand All @@ -1186,6 +1191,7 @@ def __init__(
self._region = region
self._scale_tier = scale_tier
self._master_type = master_type
self._master_config = master_config
self._runtime_version = runtime_version
self._python_version = python_version
self._job_dir = job_dir
Expand All @@ -1209,6 +1215,8 @@ def __init__(
raise AirflowException('Google Compute Engine region is required.')
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" and not self._master_type:
raise AirflowException('master_type must be set when scale_tier is CUSTOM')
if self._master_config and not self._master_type:
raise AirflowException('master_type must be set when master_config is provided')

def execute(self, context):
job_id = _normalize_mlengine_job_id(self._job_id)
Expand Down Expand Up @@ -1237,6 +1245,9 @@ def execute(self, context):
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
training_request['trainingInput']['masterType'] = self._master_type

if self._master_config:
training_request['trainingInput']['masterConfig'] = self._master_config

if self._mode == 'DRY_RUN':
self.log.info('In dry_run mode.')
self.log.info('MLEngine Training job request is: %s', training_request)
Expand Down
39 changes: 39 additions & 0 deletions tests/providers/google/cloud/operators/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,45 @@ def test_success_create_training_job(self, mock_hook):
project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY
)

@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_create_training_job_with_master_config(self, mock_hook):
    """Verify that ``master_config`` is sent as ``trainingInput.masterConfig``.

    The operator is built with ``scale_tier='CUSTOM'``, a ``master_type`` and a
    ``master_config``; the job passed to ``create_job`` on the mocked hook must
    contain both ``masterType`` and ``masterConfig``.
    """
    # Operator keyword arguments: the shared defaults with the scale tier forced to CUSTOM.
    operator_kwargs: dict = {**copy.deepcopy(self.TRAINING_DEFAULT_ARGS), 'scale_tier': 'CUSTOM'}

    # The job body the operator is expected to hand to the hook.
    expected_job = copy.deepcopy(self.TRAINING_INPUT)
    expected_job['trainingInput'].update(
        {
            'runtimeVersion': '1.6',
            'pythonVersion': '3.5',
            'jobDir': 'gs://some-bucket/jobs/test_training',
            'scaleTier': 'CUSTOM',
            'masterType': 'n1-standard-4',
            'masterConfig': {
                'acceleratorConfig': {'count': '1', 'type': 'NVIDIA_TESLA_P4'},
            },
        }
    )

    # The mocked hook reports the job as finished successfully.
    mocked_hook = mock_hook.return_value
    mocked_hook.create_job.return_value = dict(expected_job, state='SUCCEEDED')

    MLEngineStartTrainingJobOperator(
        runtime_version='1.6',
        python_version='3.5',
        job_dir='gs://some-bucket/jobs/test_training',
        master_type='n1-standard-4',
        master_config={'acceleratorConfig': {'count': '1', 'type': 'NVIDIA_TESLA_P4'}},
        **operator_kwargs,
    ).execute(MagicMock())

    mock_hook.assert_called_once_with(
        gcp_conn_id='google_cloud_default',
        delegate_to=None,
        impersonation_chain=None,
    )
    # Make sure only 'create_job' is invoked on hook instance
    self.assertEqual(len(mocked_hook.mock_calls), 1)
    mocked_hook.create_job.assert_called_once_with(
        project_id='test-project',
        job=expected_job,
        use_existing_job_fn=ANY,
    )

@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_create_training_job_with_optional_args(self, mock_hook):
training_input = copy.deepcopy(self.TRAINING_INPUT)
Expand Down

0 comments on commit 6e3d7b6

Please sign in to comment.