Skip to content

Commit

Permalink
Add ability to specify pod_template_file in executor_config (apache#1…
Browse files Browse the repository at this point in the history
…1784)

* Add pod_template_override to executor_config

Users will be able to override the base pod_template_file on a per-task
basis.

* change docstring

* fix doc

* fix static checks

* add description
  • Loading branch information
dimberman authored Nov 5, 2020
1 parent 60cf315 commit 68ba54b
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 12 deletions.
12 changes: 12 additions & 0 deletions airflow/example_dags/example_kubernetes_executor_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ def test_volume_mount():
)
# [END task_with_volume]

# [START task_with_template]
task_with_template = PythonOperator(
task_id="task_with_template",
python_callable=print_stuff,
executor_config={
"pod_template_file": "/usr/local/airflow/pod_templates/basic_template.yaml",
"pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(labels={"release": "stable"})),
},
)
# [END task_with_template]

# [START task_with_sidecar]
sidecar_task = PythonOperator(
task_id="task_with_sidecar",
Expand Down Expand Up @@ -146,3 +157,4 @@ def test_volume_mount():
start_task >> volume_task >> third_task
start_task >> other_ns_task
start_task >> sidecar_task
start_task >> task_with_template
34 changes: 27 additions & 7 deletions airflow/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import kubernetes
from dateutil import parser
from kubernetes import client, watch
from kubernetes.client import Configuration
from kubernetes.client import Configuration, models as k8s
from kubernetes.client.rest import ApiException
from urllib3.exceptions import ReadTimeoutError

Expand All @@ -50,8 +50,8 @@
from airflow.utils.session import provide_session
from airflow.utils.state import State

# TaskInstance key, command, configuration
KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any]
# TaskInstance key, command, configuration, pod_template_file
KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]

# key, state, pod_id, namespace, resource_version
KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
Expand Down Expand Up @@ -341,13 +341,14 @@ def run_next(self, next_job: KubernetesJobType) -> None:
status
"""
self.log.info('Kubernetes job is %s', str(next_job))
key, command, kube_executor_config = next_job
key, command, kube_executor_config, pod_template_file = next_job
dag_id, task_id, execution_date, try_number = key

if command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')

base_worker_pod = PodGenerator.deserialize_model_file(self.kube_config.pod_template_file)
base_worker_pod = get_base_pod_from_template(pod_template_file, self.kube_config)

if not base_worker_pod:
raise AirflowException(
f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}"
Expand Down Expand Up @@ -505,6 +506,21 @@ def create_pod_id(dag_id: str, task_id: str) -> str:
return safe_dag_id + safe_task_id


def get_base_pod_from_template(pod_template_file: Optional[str], kube_config: Any) -> k8s.V1Pod:
    """
    Build the "base pod" used by the KubernetesExecutor from a pod template file.

    The per-task ``pod_template_file`` taken from the executor_config wins when it is set;
    otherwise the ``pod_template_file`` carried by ``kube_config`` is used.

    :param pod_template_file: path to a pod_template_file.yaml, or None to fall back to kube_config
    :param kube_config: The KubeConfig object generated by airflow; only its
        ``pod_template_file`` attribute is read, and only when no per-task path is given
    :return: the V1Pod deserialized from the selected template
    """
    # ``or`` short-circuits, so kube_config is never touched when a per-task path exists.
    template_path = pod_template_file or kube_config.pod_template_file
    return PodGenerator.deserialize_model_file(template_path)


class KubernetesExecutor(BaseExecutor, LoggingMixin):
"""Executor for Kubernetes"""

Expand Down Expand Up @@ -619,10 +635,14 @@ def execute_async(
"""Executes task asynchronously"""
self.log.info('Add task %s with command %s with executor_config %s', key, command, executor_config)
kube_executor_config = PodGenerator.from_obj(executor_config)
# Per-task pod template path. The documented executor_config key is "pod_template_file"
# (see the task_with_template example); reading only "pod_template_override" meant a
# template supplied under the documented key was silently ignored. The old key is still
# honoured as a fallback so existing callers keep working.
if executor_config:
    pod_template_file = executor_config.get("pod_template_file") or executor_config.get(
        "pod_template_override"
    )
else:
    pod_template_file = None
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id)
self.task_queue.put((key, command, kube_executor_config))
self.task_queue.put((key, command, kube_executor_config, pod_template_file))

def sync(self) -> None:
"""Synchronize task state."""
Expand Down Expand Up @@ -677,7 +697,7 @@ def sync(self) -> None:
except ApiException as e:
if e.reason == "BadRequest":
self.log.error("Request was invalid. Failing task")
key, _, _ = task
key, _, _, _ = task
self.change_state(key, State.FAILED, e)
else:
self.log.warning(
Expand Down
6 changes: 3 additions & 3 deletions chart/requirements.lock
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
dependencies:
- name: postgresql
repository: https://kubernetes-charts.storage.googleapis.com/
repository: https://kubernetes-charts.storage.googleapis.com
version: 6.3.12
digest: sha256:e8d53453861c590e6ae176331634c9268a11cf894be17ed580fa2b347101be97
generated: "2020-10-27T21:16:13.0063538Z"
digest: sha256:58d88cf56e78b2380091e9e16cc6ccf58b88b3abe4a1886dd47cd9faef5309af
generated: "2020-11-04T15:59:36.967913-08:00"
10 changes: 9 additions & 1 deletion docs/executor/kubernetes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,15 @@ name ``base`` and a second container containing your desired sidecar.
:start-after: [START task_with_sidecar]
:end-before: [END task_with_sidecar]

In the following example, we create a sidecar container that shares a volume_mount for data sharing.
You can also create a custom ``pod_template_file`` on a per-task basis so that you can recycle the same base values between multiple tasks.
This will replace the default ``pod_template_file`` named in the airflow.cfg and then override that template using the ``pod_override``.

Here is an example of a task with both features:

.. exampleinclude:: /../airflow/example_dags/example_kubernetes_executor_config.py
:language: python
:start-after: [START task_with_template]
:end-before: [END task_with_template]

KubernetesExecutor Architecture
################################
Expand Down
14 changes: 14 additions & 0 deletions tests/executors/test_kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
AirflowKubernetesScheduler,
KubernetesExecutor,
create_pod_id,
get_base_pod_from_template,
)
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
Expand Down Expand Up @@ -84,6 +85,19 @@ def test_create_pod_id(self):
pod_name = PodGenerator.make_unique_pod_id(create_pod_id(dag_id, task_id))
self.assertTrue(self._is_valid_pod_id(pod_name))

@unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed')
@mock.patch("airflow.kubernetes.pod_generator.PodGenerator")
@mock.patch("airflow.executors.kubernetes_executor.KubeConfig")
def test_get_base_pod_from_template(self, mock_kubeconfig, mock_generator):
    """Check which template path get_base_pod_from_template hands on to PodGenerator."""
    # Patches are applied bottom-up: KubeConfig -> mock_kubeconfig, PodGenerator -> mock_generator.
    # Case 1: an explicit per-task path is given, so no kube config is needed (None is passed).
    pod_template_file_path = "/bar/biz"
    get_base_pod_from_template(pod_template_file_path, None)
    # Each mock_calls entry is (name, args, kwargs): [0] is the method name, [1][0] its first arg.
    # NOTE(review): the expected name is "deserialize_model_dict" although the function under test
    # calls deserialize_model_file; presumably the real deserialize_model_file delegates to the
    # patched class with the path passed through unchanged -- confirm against PodGenerator.
    self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[0][0])
    self.assertEqual(pod_template_file_path, mock_generator.mock_calls[0][1][0])
    # Case 2: no per-task path, so the path configured on the kube config must be used instead.
    mock_kubeconfig.pod_template_file = "/foo/bar"
    get_base_pod_from_template(None, mock_kubeconfig)
    self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[1][0])
    self.assertEqual("/foo/bar", mock_generator.mock_calls[1][1][0])

def test_make_safe_label_value(self):
for dag_id, task_id in self._cases():
safe_dag_id = pod_generator.make_safe_label_value(dag_id)
Expand Down
5 changes: 4 additions & 1 deletion tests/www/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,10 @@ def test_should_render_template(self):

self.assertEqual(len(templates), 1)
self.assertEqual(templates[0].name, 'airflow/redoc.html')
self.assertEqual(templates[0].local_context, {'openapi_spec_url': '/api/v1/openapi.yaml'})
self.assertEqual(
templates[0].local_context,
{'openapi_spec_url': '/api/v1/openapi.yaml'},
)


class TestLogView(TestBase):
Expand Down

0 comments on commit 68ba54b

Please sign in to comment.