Skip to content

Commit

Permalink
Ensure target_dedicated_nodes or enable_auto_scale is set in AzureBat…
Browse files Browse the repository at this point in the history
…chOperator (apache#11251)
  • Loading branch information
ephraimbuddy authored Oct 3, 2020
1 parent b7183de commit 4210618
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
6 changes: 5 additions & 1 deletion airflow/providers/microsoft/azure/operators/azure_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def _check_inputs(self) -> Any:
self.vm_publisher, self.vm_offer, self.sku_starts_with
)
)
if not self.target_dedicated_nodes and not self.enable_auto_scale:
raise AirflowException(
"Either target_dedicated_nodes or enable_auto_scale must be set. None was set"
)
if self.enable_auto_scale:
if self.target_dedicated_nodes or self.target_low_priority_nodes:
raise AirflowException(
Expand All @@ -243,7 +247,7 @@ def _check_inputs(self) -> Any:
)
)
if not self.auto_scale_formula:
raise AirflowException("The auto_scale_formula is required when enable_auto_scale is" " set")
raise AirflowException("The auto_scale_formula is required when enable_auto_scale is set")
if self.batch_job_release_task and not self.batch_job_preparation_task:
raise AirflowException(
"A batch_job_release_task cannot be specified without also "
Expand Down
67 changes: 66 additions & 1 deletion tests/providers/microsoft/azure/operators/test_azure_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@
BATCH_JOB_ID = "MyJob"
BATCH_TASK_ID = "MyTask"
BATCH_VM_SIZE = "Standard"
FORMULA = """$curTime = time();
$workHours = $curTime.hour >= 8 && $curTime.hour < 18;
$isWeekday = $curTime.weekday >= 1 && $curTime.weekday <= 5;
$isWorkingWeekdayHour = $workHours && $isWeekday;
$TargetDedicated = $isWorkingWeekdayHour ? 20:10;"""


class TestAzureBatchOperator(unittest.TestCase):
class TestAzureBatchOperator(unittest.TestCase): # pylint: disable=too-many-instance-attributes
# set up the test environment
@mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook")
@mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient")
Expand Down Expand Up @@ -89,6 +94,40 @@ def setUp(self, mock_batch, mock_hook):
)
)
self.operator = AzureBatchOperator(
task_id=TASK_ID,
batch_pool_id=BATCH_POOL_ID,
batch_pool_vm_size=BATCH_VM_SIZE,
batch_job_id=BATCH_JOB_ID,
batch_task_id=BATCH_TASK_ID,
batch_task_command_line="echo hello",
azure_batch_conn_id=self.test_vm_conn_id,
target_dedicated_nodes=1,
timeout=2,
)
self.operator2_pass = AzureBatchOperator(
task_id=TASK_ID,
batch_pool_id=BATCH_POOL_ID,
batch_pool_vm_size=BATCH_VM_SIZE,
batch_job_id=BATCH_JOB_ID,
batch_task_id=BATCH_TASK_ID,
batch_task_command_line="echo hello",
azure_batch_conn_id=self.test_vm_conn_id,
enable_auto_scale=True,
auto_scale_formula=FORMULA,
timeout=2,
)
self.operator2_no_formula = AzureBatchOperator(
task_id=TASK_ID,
batch_pool_id=BATCH_POOL_ID,
batch_pool_vm_size=BATCH_VM_SIZE,
batch_job_id=BATCH_JOB_ID,
batch_task_id=BATCH_TASK_ID,
batch_task_command_line="echo hello",
azure_batch_conn_id=self.test_vm_conn_id,
enable_auto_scale=True,
timeout=2,
)
self.operator_fail = AzureBatchOperator(
task_id=TASK_ID,
batch_pool_id=BATCH_POOL_ID,
batch_pool_vm_size=BATCH_VM_SIZE,
Expand All @@ -110,6 +149,15 @@ def test_execute_without_failures(self, wait_mock):
self.batch_client.job.add.assert_called()
self.batch_client.task.add.assert_called()

@mock.patch.object(AzureBatchHook, 'wait_for_all_node_state')
def test_execute_without_failures_2(self, wait_mock):
wait_mock.return_value = True # No wait
self.operator2_pass.execute(None)
# test pool creation
self.batch_client.pool.add.assert_called()
self.batch_client.job.add.assert_called()
self.batch_client.task.add.assert_called()

@mock.patch.object(AzureBatchHook, 'wait_for_all_node_state')
def test_execute_with_failures(self, wait_mock):
wait_mock.return_value = True # No wait
Expand All @@ -129,3 +177,20 @@ def test_execute_with_cleaning(self, mock_clean, wait_mock):
self.operator.execute(None)
mock_clean.assert_called()
mock_clean.assert_called_once_with(job_id=BATCH_JOB_ID)

@mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
def test_operator_fails(self, wait_mock):
wait_mock.return_value = True
with self.assertRaises(AirflowException) as e:
self.operator_fail.execute(None)
self.assertEqual(
str(e.exception),
"Either target_dedicated_nodes or enable_auto_scale " "must be set. None was set",
)

@mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
def test_operator_fails_no_formula(self, wait_mock):
wait_mock.return_value = True
with self.assertRaises(AirflowException) as e:
self.operator2_no_formula.execute(None)
self.assertEqual(str(e.exception), "The auto_scale_formula is required when enable_auto_scale is set")

0 comments on commit 4210618

Please sign in to comment.