Skip to content

Commit

Permalink
Add delete_topic to KafkaAdminClientHook and teardown logic to Ka…
Browse files Browse the repository at this point in the history
…fka integration tests (apache#40142)

* Add unit tests to Apache Kafka hooks

* Add teardown logic to integration tests of kafka hooks
  • Loading branch information
shahar1 authored Jun 8, 2024
1 parent 340d6b0 commit cbe6c2d
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 46 deletions.
3 changes: 1 addition & 2 deletions airflow/providers/apache/kafka/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(self, kafka_config_id=default_conn_name, *args, **kwargs):
"""Initialize our Base."""
super().__init__()
self.kafka_config_id = kafka_config_id
self.get_conn

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
Expand Down Expand Up @@ -74,6 +73,6 @@ def test_connection(self) -> tuple[bool, str]:
if t:
return True, "Connection successful."
except Exception as e:
False, str(e)
return False, str(e)

return False, "Failed to establish connection."
16 changes: 16 additions & 0 deletions airflow/providers/apache/kafka/hooks/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,19 @@ def create_topic(
self.log.warning("The topic %s already exists.", t)
else:
raise

def delete_topic(
    self,
    topics: Sequence[str],
) -> None:
    """
    Delete one or more topics and wait for every deletion to complete.

    :param topics: names of the topics to delete. Any sequence of names is
        accepted; a bare string is treated as a single topic name rather than
        being split into characters.

    Exceptions raised while waiting on a topic's future propagate to the caller.
    """
    # The parameter is annotated as a generic Sequence, but the admin client is
    # documented to take a list of names, so normalise before handing it over.
    # A lone ``str`` is wrapped (not iterated) because iterating it would yield
    # one bogus "topic" per character -- dangerous for a destructive call.
    topic_names = [topics] if isinstance(topics, str) else list(topics)

    admin_client = self.get_conn
    futures = admin_client.delete_topics(topic_names)

    # One future per requested topic; wait on each before logging success.
    for topic, future in futures.items():
        future.result()
        self.log.info("The topic %s has been deleted.", topic)
1 change: 0 additions & 1 deletion tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def test_providers_modules_should_have_tests(self):
"tests/providers/apache/hdfs/log/test_hdfs_task_handler.py",
"tests/providers/apache/hdfs/sensors/test_hdfs.py",
"tests/providers/apache/hive/plugins/test_hive.py",
"tests/providers/apache/kafka/hooks/test_base.py",
"tests/providers/celery/executors/test_celery_executor_utils.py",
"tests/providers/celery/executors/test_default_celery.py",
"tests/providers/cncf/kubernetes/backcompat/test_backwards_compat_converters.py",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ def test_hook(self):
kadmin = hook.get_conn
t = kadmin.list_topics(timeout=10).topics
assert t.get("test_2")
hook.delete_topic(topics=["test_1", "test_2"])
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from airflow.models import Connection

# Import Hook
from airflow.providers.apache.kafka.hooks.client import KafkaAdminClientHook
from airflow.providers.apache.kafka.hooks.consume import KafkaConsumerHook
from airflow.utils import db

Expand Down Expand Up @@ -68,3 +69,5 @@ def test_consume_messages(self):
msg = consumer.consume()

assert msg[0].value() == b"test_message"
hook = KafkaAdminClientHook(kafka_config_id="kafka_d")
hook.delete_topic(topics=[TOPIC])
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

from airflow.models import Connection
from airflow.providers.apache.kafka.hooks.client import KafkaAdminClientHook
from airflow.providers.apache.kafka.hooks.produce import KafkaProducerHook
from airflow.utils import db

Expand Down Expand Up @@ -61,7 +62,8 @@ def acked(err, msg):
p_hook = KafkaProducerHook(kafka_config_id="kafka_default")

producer = p_hook.get_producer()

producer.produce(topic, key="p1", value="p2", on_delivery=acked)
producer.poll(0)
producer.flush()
hook = KafkaAdminClientHook(kafka_config_id="kafka_default")
hook.delete_topic(topics=[topic])
81 changes: 81 additions & 0 deletions tests/providers/apache/kafka/hooks/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.providers.apache.kafka.hooks.base import KafkaBaseHook


class SomeKafkaHook(KafkaBaseHook):
    """Minimal concrete ``KafkaBaseHook`` subclass used as the hook under test."""

    def _get_client(self, config):
        """Return the supplied config unchanged in place of a real client."""
        client = config
        return client


@pytest.fixture
def hook():
    """Provide a ``SomeKafkaHook`` built with its default arguments."""
    instance = SomeKafkaHook()
    return instance


# Value the tests below expect the hook to pass to AdminClient as ``timeout=`` when testing the connection.
TIMEOUT = 10


class TestKafkaBaseHook:
    """Tests for ``KafkaBaseHook`` driven through ``SomeKafkaHook`` with mocked connections."""

    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
    def test_get_conn(self, mock_get_connection, hook):
        """``get_conn`` evaluates to the connection's extra config (the test hook echoes it back)."""
        config = {"bootstrap.servers": MagicMock()}
        mock_get_connection.return_value.extra_dejson = config
        assert hook.get_conn == config

    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
    def test_get_conn_value_error(self, mock_get_connection, hook):
        """An empty extra config makes ``get_conn`` raise ``ValueError``."""
        mock_get_connection.return_value.extra_dejson = {}
        with pytest.raises(ValueError, match="must be provided"):
            # ``get_conn`` is read as a plain attribute everywhere else in these
            # tests; the access itself raises, so it must not be called.
            _ = hook.get_conn

    @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient")
    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
    def test_test_connection(self, mock_get_connection, admin_client, hook):
        """With a fully mocked ``AdminClient`` the connection test reports success."""
        config = {"bootstrap.servers": MagicMock()}
        mock_get_connection.return_value.extra_dejson = config
        connection = hook.test_connection()
        # Use the shared constant so every test agrees on the expected timeout.
        admin_client.assert_called_once_with(config, timeout=TIMEOUT)
        assert connection == (True, "Connection successful.")

    @mock.patch(
        "airflow.providers.apache.kafka.hooks.base.AdminClient",
        return_value=MagicMock(list_topics=MagicMock(return_value=[])),
    )
    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
    def test_test_connection_no_topics(self, mock_get_connection, admin_client, hook):
        """An empty ``list_topics`` result yields the generic failure message."""
        config = {"bootstrap.servers": MagicMock()}
        mock_get_connection.return_value.extra_dejson = config
        connection = hook.test_connection()
        admin_client.assert_called_once_with(config, timeout=TIMEOUT)
        assert connection == (False, "Failed to establish connection.")

    @mock.patch("airflow.providers.apache.kafka.hooks.base.AdminClient")
    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
    def test_test_connection_exception(self, mock_get_connection, admin_client, hook):
        """An exception from ``list_topics`` is reported as ``(False, <message>)``."""
        config = {"bootstrap.servers": MagicMock()}
        mock_get_connection.return_value.extra_dejson = config
        admin_client.return_value.list_topics.side_effect = [ValueError("some error")]
        connection = hook.test_connection()
        assert connection == (False, "some error")
81 changes: 57 additions & 24 deletions tests/providers/apache/kafka/hooks/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

import json
import logging
from unittest.mock import MagicMock, patch

import pytest
from confluent_kafka.admin import AdminClient
from confluent_kafka import KafkaException
from confluent_kafka.admin import AdminClient, NewTopic

from airflow.models import Connection
from airflow.providers.apache.kafka.hooks.client import KafkaAdminClientHook
Expand All @@ -31,11 +33,7 @@
log = logging.getLogger(__name__)


class TestSampleHook:
"""
Test Admin Client Hook.
"""

class TestKafkaAdminClientHook:
def setup_method(self):
db.merge_conn(
Connection(
Expand All @@ -54,23 +52,58 @@ def setup_method(self):
extra=json.dumps({"socket.timeout.ms": 10}),
)
)

def test_init(self):
"""test initialization of AdminClientHook"""

# Standard Init
KafkaAdminClientHook(kafka_config_id="kafka_d")

# # Not Enough Args
with pytest.raises(ValueError):
KafkaAdminClientHook(kafka_config_id="kafka_bad")
self.hook = KafkaAdminClientHook(kafka_config_id="kafka_d")

def test_get_conn(self):
"""test get_conn"""

# Standard Init
k = KafkaAdminClientHook(kafka_config_id="kafka_d")

c = k.get_conn

assert isinstance(c, AdminClient)
assert isinstance(self.hook.get_conn, AdminClient)

@patch("airflow.providers.apache.kafka.hooks.client.AdminClient")
def test_create_topic(self, admin_client):
    """Creating a topic passes a ``NewTopic`` to the client and waits on its future."""
    creation_future = MagicMock()
    create_topics = admin_client.return_value.create_topics
    create_topics.return_value = {"topic_name": creation_future}

    self.hook.create_topic(topics=[("topic_name", 0, 1)])

    create_topics.assert_called_with([NewTopic("topic_name", 0, 1)])
    creation_future.result.assert_called_once()

@patch(
    "airflow.providers.apache.kafka.hooks.client.AdminClient",
)
def test_create_topic_error(self, admin_client):
    """A ``KafkaException`` other than "topic already exists" is re-raised by ``create_topic``."""
    mock_f = MagicMock()
    kafka_exception = KafkaException()
    # ``name`` is deliberately left unset: the auto-generated MagicMock attribute
    # never equals "TOPIC_ALREADY_EXISTS" (the value the warning test sets), so
    # this exercises the re-raise path rather than the warning path.
    mock_arg = MagicMock()
    kafka_exception.args = [mock_arg]
    mock_f.result.side_effect = [kafka_exception]
    admin_client.return_value.create_topics.return_value = {"topic_name": mock_f}
    with pytest.raises(KafkaException):
        self.hook.create_topic(topics=[("topic_name", 0, 1)])

@patch("airflow.providers.apache.kafka.hooks.client.AdminClient")
def test_create_topic_warning(self, admin_client, caplog):
    """An "already exists" failure is logged as a warning instead of being raised."""
    error_code = MagicMock()
    error_code.name = "TOPIC_ALREADY_EXISTS"
    already_exists = KafkaException()
    already_exists.args = [error_code]

    failing_future = MagicMock()
    failing_future.result.side_effect = [already_exists]
    admin_client.return_value.create_topics.return_value = {"topic_name": failing_future}

    hook_logger = "airflow.providers.apache.kafka.hooks.client.KafkaAdminClientHook"
    with caplog.at_level(logging.WARNING, logger=hook_logger):
        self.hook.create_topic(topics=[("topic_name", 0, 1)])
        assert "The topic topic_name already exists" in caplog.text

@patch("airflow.providers.apache.kafka.hooks.client.AdminClient")
def test_delete_topic(self, admin_client):
    """Deleting a topic forwards the names to the client and waits on the returned future."""
    deletion_future = MagicMock()
    delete_topics = admin_client.return_value.delete_topics
    delete_topics.return_value = {"topic_name": deletion_future}

    self.hook.delete_topic(topics=["topic_name"])

    delete_topics.assert_called_with(["topic_name"])
    deletion_future.result.assert_called_once()
12 changes: 3 additions & 9 deletions tests/providers/apache/kafka/hooks/test_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,7 @@ def setup_method(self):
extra=json.dumps({}),
)
)
self.hook = KafkaConsumerHook(["test_1"], kafka_config_id="kafka_d")

def test_init(self):
"""test initialization of AdminClientHook"""

# Standard Init
KafkaConsumerHook(["test_1"], kafka_config_id="kafka_d")

# Not Enough Args
with pytest.raises(ValueError):
KafkaConsumerHook(["test_1"], kafka_config_id="kafka_bad")
def test_get_consumer(self):
assert self.hook.get_consumer() == self.hook.get_conn
12 changes: 3 additions & 9 deletions tests/providers/apache/kafka/hooks/test_produce.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,7 @@ def setup_method(self):
extra=json.dumps({}),
)
)
self.hook = KafkaProducerHook(kafka_config_id="kafka_d")

def test_init(self):
"""test initialization of AdminClientHook"""

# Standard Init
KafkaProducerHook(kafka_config_id="kafka_d")

# Not Enough Args
with pytest.raises(ValueError):
KafkaProducerHook(kafka_config_id="kafka_bad")
def test_get_producer(self):
assert self.hook.get_producer() == self.hook.get_conn

0 comments on commit cbe6c2d

Please sign in to comment.