From 68ba54bbd5a275fba1a126f8e67bd69e5cf4b362 Mon Sep 17 00:00:00 2001 From: Daniel Imberman Date: Thu, 5 Nov 2020 14:48:05 -0800 Subject: [PATCH] Add ability to specify pod_template_file in executor_config (#11784) * Add pod_template_override to executor_config Users will be able to override the base pod_template_file on a per-task basis. * change docstring * fix doc * fix static checks * add description --- .../example_kubernetes_executor_config.py | 12 +++++++ airflow/executors/kubernetes_executor.py | 34 +++++++++++++++---- chart/requirements.lock | 6 ++-- docs/executor/kubernetes.rst | 10 +++++- tests/executors/test_kubernetes_executor.py | 14 ++++++++ tests/www/test_views.py | 5 ++- 6 files changed, 69 insertions(+), 12 deletions(-) diff --git a/airflow/example_dags/example_kubernetes_executor_config.py b/airflow/example_dags/example_kubernetes_executor_config.py index 57b2c4a8435e0..5fbc5750ecca5 100644 --- a/airflow/example_dags/example_kubernetes_executor_config.py +++ b/airflow/example_dags/example_kubernetes_executor_config.py @@ -99,6 +99,17 @@ def test_volume_mount(): ) # [END task_with_volume] + # [START task_with_template] + task_with_template = PythonOperator( + task_id="task_with_template", + python_callable=print_stuff, + executor_config={ + "pod_template_file": "/usr/local/airflow/pod_templates/basic_template.yaml", + "pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(labels={"release": "stable"})), + }, + ) + # [END task_with_template] + # [START task_with_sidecar] sidecar_task = PythonOperator( task_id="task_with_sidecar", @@ -146,3 +157,4 @@ def test_volume_mount(): start_task >> volume_task >> third_task start_task >> other_ns_task start_task >> sidecar_task + start_task >> task_with_template diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index ad200e1ab97d0..b7071a0bc9113 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -32,7 +32,7 @@ 
import kubernetes from dateutil import parser from kubernetes import client, watch -from kubernetes.client import Configuration +from kubernetes.client import Configuration, models as k8s from kubernetes.client.rest import ApiException from urllib3.exceptions import ReadTimeoutError @@ -50,8 +50,8 @@ from airflow.utils.session import provide_session from airflow.utils.state import State -# TaskInstance key, command, configuration -KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any] +# TaskInstance key, command, configuration, pod_template_file +KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]] # key, state, pod_id, namespace, resource_version KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str] @@ -341,13 +341,14 @@ def run_next(self, next_job: KubernetesJobType) -> None: status """ self.log.info('Kubernetes job is %s', str(next_job)) - key, command, kube_executor_config = next_job + key, command, kube_executor_config, pod_template_file = next_job dag_id, task_id, execution_date, try_number = key if command[0:3] != ["airflow", "tasks", "run"]: raise ValueError('The command must start with ["airflow", "tasks", "run"].') - base_worker_pod = PodGenerator.deserialize_model_file(self.kube_config.pod_template_file) + base_worker_pod = get_base_pod_from_template(pod_template_file, self.kube_config) + if not base_worker_pod: raise AirflowException( f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}" @@ -505,6 +506,21 @@ def create_pod_id(dag_id: str, task_id: str) -> str: return safe_dag_id + safe_task_id +def get_base_pod_from_template(pod_template_file: Optional[str], kube_config: Any) -> k8s.V1Pod: + """ + Reads either the pod_template_file set in the executor_config or the base pod_template_file + set in the airflow.cfg to craft a "base pod" that will be used by the KubernetesExecutor + + :param pod_template_file: absolute path to a pod_template_file.yaml or None + :param 
kube_config: The KubeConfig class generated by airflow that contains all kube metadata + :return: a V1Pod that can be used as the base pod for k8s tasks + """ + if pod_template_file: + return PodGenerator.deserialize_model_file(pod_template_file) + else: + return PodGenerator.deserialize_model_file(kube_config.pod_template_file) + + class KubernetesExecutor(BaseExecutor, LoggingMixin): """Executor for Kubernetes""" @@ -619,10 +635,14 @@ def execute_async( """Executes task asynchronously""" self.log.info('Add task %s with command %s with executor_config %s', key, command, executor_config) kube_executor_config = PodGenerator.from_obj(executor_config) + if executor_config: + pod_template_file = executor_config.get("pod_template_file", None) + else: + pod_template_file = None if not self.task_queue: raise AirflowException(NOT_STARTED_MESSAGE) self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id) - self.task_queue.put((key, command, kube_executor_config)) + self.task_queue.put((key, command, kube_executor_config, pod_template_file)) def sync(self) -> None: """Synchronize task state.""" @@ -677,7 +697,7 @@ def sync(self) -> None: except ApiException as e: if e.reason == "BadRequest": self.log.error("Request was invalid. 
Failing task") - key, _, _ = task + key, _, _, _ = task self.change_state(key, State.FAILED, e) else: self.log.warning( diff --git a/chart/requirements.lock b/chart/requirements.lock index 715458e63560a..eb62c80284a3b 100644 --- a/chart/requirements.lock +++ b/chart/requirements.lock @@ -1,6 +1,6 @@ dependencies: - name: postgresql - repository: https://kubernetes-charts.storage.googleapis.com/ + repository: https://kubernetes-charts.storage.googleapis.com version: 6.3.12 -digest: sha256:e8d53453861c590e6ae176331634c9268a11cf894be17ed580fa2b347101be97 -generated: "2020-10-27T21:16:13.0063538Z" +digest: sha256:58d88cf56e78b2380091e9e16cc6ccf58b88b3abe4a1886dd47cd9faef5309af +generated: "2020-11-04T15:59:36.967913-08:00" diff --git a/docs/executor/kubernetes.rst b/docs/executor/kubernetes.rst index f3940d898012f..c01672623f82f 100644 --- a/docs/executor/kubernetes.rst +++ b/docs/executor/kubernetes.rst @@ -123,7 +123,15 @@ name ``base`` and a second container containing your desired sidecar. :start-after: [START task_with_sidecar] :end-before: [END task_with_sidecar] -In the following example, we create a sidecar container that shares a volume_mount for data sharing. +You can also create a custom ``pod_template_file`` on a per-task basis so that you can recycle the same base values between multiple tasks. +This will replace the default ``pod_template_file`` named in the airflow.cfg and then override that template using the ``pod_override``. + +Here is an example of a task with both features: + +.. 
exampleinclude:: /../airflow/example_dags/example_kubernetes_executor_config.py + :language: python + :start-after: [START task_with_template] + :end-before: [END task_with_template] KubernetesExecutor Architecture ################################ diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 9765669d32aa9..9d8d72f4ef0b4 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -35,6 +35,7 @@ AirflowKubernetesScheduler, KubernetesExecutor, create_pod_id, + get_base_pod_from_template, ) from airflow.kubernetes import pod_generator from airflow.kubernetes.pod_generator import PodGenerator @@ -84,6 +85,19 @@ def test_create_pod_id(self): pod_name = PodGenerator.make_unique_pod_id(create_pod_id(dag_id, task_id)) self.assertTrue(self._is_valid_pod_id(pod_name)) + @unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed') + @mock.patch("airflow.kubernetes.pod_generator.PodGenerator") + @mock.patch("airflow.executors.kubernetes_executor.KubeConfig") + def test_get_base_pod_from_template(self, mock_kubeconfig, mock_generator): + pod_template_file_path = "/bar/biz" + get_base_pod_from_template(pod_template_file_path, None) + self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[0][0]) + self.assertEqual(pod_template_file_path, mock_generator.mock_calls[0][1][0]) + mock_kubeconfig.pod_template_file = "/foo/bar" + get_base_pod_from_template(None, mock_kubeconfig) + self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[1][0]) + self.assertEqual("/foo/bar", mock_generator.mock_calls[1][1][0]) + def test_make_safe_label_value(self): for dag_id, task_id in self._cases(): safe_dag_id = pod_generator.make_safe_label_value(dag_id) diff --git a/tests/www/test_views.py b/tests/www/test_views.py index e17180579447c..901e081885a7b 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ 
-1164,7 +1164,10 @@ def test_should_render_template(self): self.assertEqual(len(templates), 1) self.assertEqual(templates[0].name, 'airflow/redoc.html') - self.assertEqual(templates[0].local_context, {'openapi_spec_url': '/api/v1/openapi.yaml'}) + self.assertEqual( + templates[0].local_context, + {'openapi_spec_url': '/api/v1/openapi.yaml'}, + ) class TestLogView(TestBase):