Skip to content

Commit

Permalink
[FEATURE] google provider - split GkeStartPodOperator execute (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelauv authored May 10, 2022
1 parent faae9fa commit 60a1d9d
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import os
import tempfile
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Generator, Optional, Sequence, Union

from google.cloud.container_v1.types import Cluster

Expand Down Expand Up @@ -336,11 +337,22 @@ def __init__(
if self.config_file:
raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.")

def execute(self, context: 'Context') -> Optional[str]:
hook = GoogleBaseHook(gcp_conn_id=self.gcp_conn_id)
self.project_id = self.project_id or hook.project_id
@staticmethod
@contextmanager
def get_gke_config_file(
gcp_conn_id,
project_id: Optional[str],
cluster_name: str,
impersonation_chain: Optional[Union[str, Sequence[str]]],
regional: bool,
location: str,
use_internal_ip: bool,
) -> Generator[str, None, None]:

if not self.project_id:
hook = GoogleBaseHook(gcp_conn_id=gcp_conn_id)
project_id = project_id or hook.project_id

if not project_id:
raise AirflowException(
"The project id must be passed either as "
"keyword project_id parameter or as project_id extra "
Expand All @@ -363,15 +375,15 @@ def execute(self, context: 'Context') -> Optional[str]:
"container",
"clusters",
"get-credentials",
self.cluster_name,
cluster_name,
"--project",
self.project_id,
project_id,
]
if self.impersonation_chain:
if isinstance(self.impersonation_chain, str):
impersonation_account = self.impersonation_chain
elif len(self.impersonation_chain) == 1:
impersonation_account = self.impersonation_chain[0]
if impersonation_chain:
if isinstance(impersonation_chain, str):
impersonation_account = impersonation_chain
elif len(impersonation_chain) == 1:
impersonation_account = impersonation_chain[0]
else:
raise AirflowException(
"Chained list of accounts is not supported, please specify only one service account"
Expand All @@ -383,15 +395,28 @@ def execute(self, context: 'Context') -> Optional[str]:
impersonation_account,
]
)
if self.regional:
if regional:
cmd.append('--region')
else:
cmd.append('--zone')
cmd.append(self.location)
if self.use_internal_ip:
cmd.append(location)
if use_internal_ip:
cmd.append('--internal-ip')
execute_in_subprocess(cmd)

# Tell `KubernetesPodOperator` where the config file is located
self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
yield os.environ[KUBE_CONFIG_ENV_VAR]

def execute(self, context: 'Context') -> Optional[str]:

with GKEStartPodOperator.get_gke_config_file(
gcp_conn_id=self.gcp_conn_id,
project_id=self.project_id,
cluster_name=self.cluster_name,
impersonation_chain=self.impersonation_chain,
regional=self.regional,
location=self.location,
use_internal_ip=self.use_internal_ip,
) as config_file:
self.config_file = config_file
return super().execute(context)

0 comments on commit 60a1d9d

Please sign in to comment.