Skip to content

Commit

Permalink
Default client in KafkaBaseHook (apache#40284)
Browse files Browse the repository at this point in the history
* Default client in KafkaBaseHook
Moving the definition of AdminClient as the default client for KafkaBaseHook
  • Loading branch information
riccardoforzan authored Jun 18, 2024
1 parent 5a3823d commit 68aa427
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/apache/kafka/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
},
}

def _get_client(self, config):
raise NotImplementedError
def _get_client(self, config) -> Any:
return AdminClient(config)

@cached_property
def get_conn(self) -> Any:
Expand Down
5 changes: 1 addition & 4 deletions airflow/providers/apache/kafka/hooks/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Sequence

from confluent_kafka import KafkaException
from confluent_kafka.admin import AdminClient, NewTopic
from confluent_kafka.admin import NewTopic

from airflow.providers.apache.kafka.hooks.base import KafkaBaseHook

Expand All @@ -34,9 +34,6 @@ class KafkaAdminClientHook(KafkaBaseHook):
def __init__(self, kafka_config_id=KafkaBaseHook.default_conn_name) -> None:
super().__init__(kafka_config_id=kafka_config_id)

def _get_client(self, config) -> AdminClient:
return AdminClient(config)

def create_topic(
self,
topics: Sequence[Sequence[Any]],
Expand Down
8 changes: 4 additions & 4 deletions tests/providers/apache/kafka/hooks/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_get_conn(self):
assert isinstance(self.hook.get_conn, AdminClient)

@patch(
"airflow.providers.apache.kafka.hooks.client.AdminClient",
"airflow.providers.apache.kafka.hooks.base.AdminClient",
)
def test_create_topic(self, admin_client):
mock_f = MagicMock()
Expand All @@ -68,7 +68,7 @@ def test_create_topic(self, admin_client):
mock_f.result.assert_called_once()

@patch(
"airflow.providers.apache.kafka.hooks.client.AdminClient",
"airflow.providers.apache.kafka.hooks.base.AdminClient",
)
def test_create_topic_error(self, admin_client):
mock_f = MagicMock()
Expand All @@ -82,7 +82,7 @@ def test_create_topic_error(self, admin_client):
self.hook.create_topic(topics=[("topic_name", 0, 1)])

@patch(
"airflow.providers.apache.kafka.hooks.client.AdminClient",
"airflow.providers.apache.kafka.hooks.base.AdminClient",
)
def test_create_topic_warning(self, admin_client, caplog):
mock_f = MagicMock()
Expand All @@ -99,7 +99,7 @@ def test_create_topic_warning(self, admin_client, caplog):
assert "The topic topic_name already exists" in caplog.text

@patch(
"airflow.providers.apache.kafka.hooks.client.AdminClient",
"airflow.providers.apache.kafka.hooks.base.AdminClient",
)
def test_delete_topic(self, admin_client):
mock_f = MagicMock()
Expand Down

0 comments on commit 68aa427

Please sign in to comment.