diff --git a/airflow/providers/microsoft/azure/operators/azure_batch.py b/airflow/providers/microsoft/azure/operators/azure_batch.py index 1c8610a32e231..54b8e36015858 100644 --- a/airflow/providers/microsoft/azure/operators/azure_batch.py +++ b/airflow/providers/microsoft/azure/operators/azure_batch.py @@ -232,6 +232,10 @@ def _check_inputs(self) -> Any: self.vm_publisher, self.vm_offer, self.sku_starts_with ) ) + if not self.target_dedicated_nodes and not self.enable_auto_scale: + raise AirflowException( + "Either target_dedicated_nodes or enable_auto_scale must be set. None was set" + ) if self.enable_auto_scale: if self.target_dedicated_nodes or self.target_low_priority_nodes: raise AirflowException( @@ -243,7 +247,7 @@ def _check_inputs(self) -> Any: ) ) if not self.auto_scale_formula: - raise AirflowException("The auto_scale_formula is required when enable_auto_scale is" " set") + raise AirflowException("The auto_scale_formula is required when enable_auto_scale is set") if self.batch_job_release_task and not self.batch_job_preparation_task: raise AirflowException( "A batch_job_release_task cannot be specified without also " diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py index 3ef4f69785979..89ed85c208ba9 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_batch.py +++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py @@ -32,9 +32,14 @@ BATCH_JOB_ID = "MyJob" BATCH_TASK_ID = "MyTask" BATCH_VM_SIZE = "Standard" +FORMULA = """$curTime = time(); + $workHours = $curTime.hour >= 8 && $curTime.hour < 18; + $isWeekday = $curTime.weekday >= 1 && $curTime.weekday <= 5; + $isWorkingWeekdayHour = $workHours && $isWeekday; + $TargetDedicated = $isWorkingWeekdayHour ? 20:10;""" -class TestAzureBatchOperator(unittest.TestCase): +class TestAzureBatchOperator(unittest.TestCase): # pylint: disable=too-many-instance-attributes # set up the test environment @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook") @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient") @@ -89,6 +94,40 @@ def setUp(self, mock_batch, mock_hook): ) ) self.operator = AzureBatchOperator( + task_id=TASK_ID, + batch_pool_id=BATCH_POOL_ID, + batch_pool_vm_size=BATCH_VM_SIZE, + batch_job_id=BATCH_JOB_ID, + batch_task_id=BATCH_TASK_ID, + batch_task_command_line="echo hello", + azure_batch_conn_id=self.test_vm_conn_id, + target_dedicated_nodes=1, + timeout=2, + ) + self.operator2_pass = AzureBatchOperator( + task_id=TASK_ID, + batch_pool_id=BATCH_POOL_ID, + batch_pool_vm_size=BATCH_VM_SIZE, + batch_job_id=BATCH_JOB_ID, + batch_task_id=BATCH_TASK_ID, + batch_task_command_line="echo hello", + azure_batch_conn_id=self.test_vm_conn_id, + enable_auto_scale=True, + auto_scale_formula=FORMULA, + timeout=2, + ) + self.operator2_no_formula = AzureBatchOperator( + task_id=TASK_ID, + batch_pool_id=BATCH_POOL_ID, + batch_pool_vm_size=BATCH_VM_SIZE, + batch_job_id=BATCH_JOB_ID, + batch_task_id=BATCH_TASK_ID, + batch_task_command_line="echo hello", + azure_batch_conn_id=self.test_vm_conn_id, + enable_auto_scale=True, + timeout=2, + ) + self.operator_fail = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, @@ -110,6 +149,15 @@ def test_execute_without_failures(self, wait_mock): self.batch_client.job.add.assert_called() self.batch_client.task.add.assert_called() + @mock.patch.object(AzureBatchHook, 'wait_for_all_node_state') + def test_execute_without_failures_2(self, wait_mock): + wait_mock.return_value = True # No wait + self.operator2_pass.execute(None) + # test pool creation + self.batch_client.pool.add.assert_called() + self.batch_client.job.add.assert_called() + self.batch_client.task.add.assert_called() + @mock.patch.object(AzureBatchHook, 'wait_for_all_node_state') def test_execute_with_failures(self, wait_mock): wait_mock.return_value = True # No wait @@ -129,3 +177,20 @@ def test_execute_with_cleaning(self, mock_clean, wait_mock): self.operator.execute(None) mock_clean.assert_called() mock_clean.assert_called_once_with(job_id=BATCH_JOB_ID) + + @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") + def test_operator_fails(self, wait_mock): + wait_mock.return_value = True + with self.assertRaises(AirflowException) as e: + self.operator_fail.execute(None) + self.assertEqual( + str(e.exception), + "Either target_dedicated_nodes or enable_auto_scale " "must be set. None was set", + ) + + @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") + def test_operator_fails_no_formula(self, wait_mock): + wait_mock.return_value = True + with self.assertRaises(AirflowException) as e: + self.operator2_no_formula.execute(None) + self.assertEqual(str(e.exception), "The auto_scale_formula is required when enable_auto_scale is set")