Skip to content

Commit

Permalink
Add on kill to ssh (apache#40377)
Browse files Browse the repository at this point in the history
  • Loading branch information
MRLab12 authored Jul 24, 2024
1 parent 68b3159 commit c139fbd
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 76 deletions.
153 changes: 80 additions & 73 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,87 +286,94 @@ def host_proxy(self) -> paramiko.ProxyCommand | None:

def get_conn(self) -> paramiko.SSHClient:
"""Establish an SSH connection to the remote host."""
self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
client = paramiko.SSHClient()

if self.allow_host_key_change:
self.log.warning(
"Remote Identification Change is not verified. "
"This won't protect against Man-In-The-Middle attacks"
)
# to avoid BadHostKeyException, skip loading host keys
client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
else:
client.load_system_host_keys()

if self.no_host_key_check:
self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) # nosec B507
# to avoid BadHostKeyException, skip loading and saving host keys
known_hosts = os.path.expanduser("~/.ssh/known_hosts")
if not self.allow_host_key_change and os.path.isfile(known_hosts):
client.load_host_keys(known_hosts)

elif self.host_key is not None:
# Use the host key from the connection extra; if it is not set or None, we fall back to system host keys
client_host_keys = client.get_host_keys()
if self.port == SSH_PORT:
client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
else:
client_host_keys.add(
f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
if self.client is None:
self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
client = paramiko.SSHClient()

if self.allow_host_key_change:
self.log.warning(
"Remote Identification Change is not verified. "
"This won't protect against Man-In-The-Middle attacks"
)
# to avoid BadHostKeyException, skip loading host keys
client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
else:
client.load_system_host_keys()

connect_kwargs: dict[str, Any] = {
"hostname": self.remote_host,
"username": self.username,
"timeout": self.conn_timeout,
"compress": self.compress,
"port": self.port,
"sock": self.host_proxy,
"look_for_keys": self.look_for_keys,
"banner_timeout": self.banner_timeout,
}
if self.no_host_key_check:
self.log.warning(
"No Host Key Verification. This won't protect against Man-In-The-Middle attacks"
)
client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) # nosec B507
# to avoid BadHostKeyException, skip loading and saving host keys
known_hosts = os.path.expanduser("~/.ssh/known_hosts")
if not self.allow_host_key_change and os.path.isfile(known_hosts):
client.load_host_keys(known_hosts)

elif self.host_key is not None:
# Use the host key from the connection extra; if it is not set or None, we fall back to system host keys
client_host_keys = client.get_host_keys()
if self.port == SSH_PORT:
client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
else:
client_host_keys.add(
f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
)

if self.password:
password = self.password.strip()
connect_kwargs.update(password=password)
connect_kwargs: dict[str, Any] = {
"hostname": self.remote_host,
"username": self.username,
"timeout": self.conn_timeout,
"compress": self.compress,
"port": self.port,
"sock": self.host_proxy,
"look_for_keys": self.look_for_keys,
"banner_timeout": self.banner_timeout,
}

if self.password:
password = self.password.strip()
connect_kwargs.update(password=password)

if self.pkey:
connect_kwargs.update(pkey=self.pkey)

if self.key_file:
connect_kwargs.update(key_filename=self.key_file)

if self.disabled_algorithms:
connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)

def log_before_sleep(retry_state):
return self.log.info(
"Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
)

if self.pkey:
connect_kwargs.update(pkey=self.pkey)
for attempt in Retrying(
reraise=True,
wait=wait_fixed(3) + wait_random(0, 2),
stop=stop_after_attempt(3),
before_sleep=log_before_sleep,
):
with attempt:
client.connect(**connect_kwargs)

if self.key_file:
connect_kwargs.update(key_filename=self.key_file)
if self.keepalive_interval:
# MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
# type "Transport | None" and item "None" has no attribute "set_keepalive".
client.get_transport().set_keepalive(self.keepalive_interval) # type: ignore[union-attr]

if self.disabled_algorithms:
connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)
if self.ciphers:
# MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
# type "Transport | None" and item "None" has no method `get_security_options`".
client.get_transport().get_security_options().ciphers = self.ciphers # type: ignore[union-attr]

def log_before_sleep(retry_state):
return self.log.info(
"Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
)
self.client = client
return client

for attempt in Retrying(
reraise=True,
wait=wait_fixed(3) + wait_random(0, 2),
stop=stop_after_attempt(3),
before_sleep=log_before_sleep,
):
with attempt:
client.connect(**connect_kwargs)

if self.keepalive_interval:
# MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
# type "Transport | None" and item "None" has no attribute "set_keepalive".
client.get_transport().set_keepalive(self.keepalive_interval) # type: ignore[union-attr]

if self.ciphers:
# MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
# type "Transport | None" and item "None" has no method `get_security_options`".
client.get_transport().get_security_options().ciphers = self.ciphers # type: ignore[union-attr]

self.client = client
return client
else:
# Return the existing connection
return self.client

@deprecated(
reason=(
Expand Down
10 changes: 10 additions & 0 deletions airflow/providers/ssh/operators/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,19 @@ def execute(self, context=None) -> bytes | str:
enable_pickling = conf.getboolean("core", "enable_xcom_pickling")
if not enable_pickling:
result = b64encode(result).decode("utf-8")

return result

def tunnel(self) -> None:
    """Fetch the hook's SSH connection and request its transport.

    The transport object is not kept or returned; this method only triggers
    ``get_conn()`` on the hook followed by ``get_transport()`` on the client.
    """
    self.hook.get_conn().get_transport()  # type: ignore[union-attr]

def on_kill(self) -> None:
    """Close the hook's SSH client, if one exists, when the task is killed.

    After closing, the hook's ``client`` attribute is reset to ``None`` so the
    hook no longer holds a reference to a closed client (a hook that caches
    its client would otherwise hand the dead connection out again).
    """
    hook = self.hook
    ssh_client = hook.client
    if not ssh_client:
        self.log.info("No SSH client to close.")
        return

    try:
        ssh_client.close()
    finally:
        # Drop the reference even if close() raises: the client is unusable either way.
        hook.client = None
    self.log.info("SSH client closed.")
5 changes: 4 additions & 1 deletion tests/providers/amazon/aws/transfers/test_s3_to_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,13 @@ def test_s3_to_sftp_operation(self):
assert not s3_hook.check_for_bucket(self.s3_bucket)

def delete_remote_resource(self):
# Initiate SSH hook
hook = SSHHook(ssh_conn_id="ssh_default")
hook.no_host_key_check = True
# check the remote file content
remove_file_task = SSHOperator(
task_id="test_rm_file",
ssh_hook=self.hook,
ssh_hook=hook,
command=f"rm {self.sftp_path}",
do_xcom_push=True,
dag=self.dag,
Expand Down
4 changes: 4 additions & 0 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,10 @@ def test_ssh_connection_with_no_host_key_check_true_and_allow_host_key_changes_f

ssh_mock.reset_mock()
with mock.patch("os.path.isfile", return_value=False):
# Reset ssh hook to initial state
hook = SSHHook(
ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_FALSE
)
with hook.get_conn():
assert ssh_mock.return_value.set_missing_host_key_policy.called is True
assert isinstance(
Expand Down
32 changes: 30 additions & 2 deletions tests/providers/ssh/operators/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from __future__ import annotations

import random
import time
from datetime import timedelta
from unittest import mock

import pytest
from paramiko.client import SSHClient

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout
from airflow.models import TaskInstance
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
Expand Down Expand Up @@ -192,13 +194,14 @@ def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out):
assert task.get_pty == get_pty_out

def test_ssh_client_managed_correctly(self):
# Ensure connection gets closed once (via context_manager)
# Ensure connection gets closed once (via context_manager) using on_kill
task = SSHOperator(
task_id="test",
ssh_hook=self.hook,
command="ls",
)
task.execute()

self.hook.get_conn.assert_called_once()
self.hook.get_conn.return_value.__exit__.assert_called_once()

Expand Down Expand Up @@ -266,3 +269,28 @@ def test_push_ssh_exit_to_xcom(self, request, dag_maker):
with pytest.raises(AirflowException, match=f"SSH operator error: exit status = {ssh_exit_code}"):
ti.run()
assert ti.xcom_pull(task_ids=task.task_id, key="ssh_exit") == ssh_exit_code

def test_timeout_triggers_on_kill(self, request, dag_maker):
    """Check that a task hitting ``execution_timeout`` raises and calls ``on_kill`` once."""

    def _hang(*args, **kwargs):
        # Stand-in for a remote command that never finishes; the task timeout interrupts it.
        time.sleep(100)

    self.exec_ssh_client_command.side_effect = _hang

    with dag_maker(dag_id=f"dag_{request.node.name}"):
        timeout_task = SSHOperator(
            task_id="test_timeout",
            ssh_hook=self.hook,
            command="sleep 100",
            execution_timeout=timedelta(seconds=1),
        )
    dag_run = dag_maker.create_dagrun(run_id="test_timeout")
    task_instance = TaskInstance(task=timeout_task, run_id=dag_run.run_id)

    with mock.patch.object(SSHOperator, "on_kill") as patched_on_kill:
        with pytest.raises(AirflowTaskTimeout):
            task_instance.run()

        # Short grace period so on_kill has time to be invoked before asserting.
        time.sleep(1)

        patched_on_kill.assert_called_once()

0 comments on commit c139fbd

Please sign in to comment.