Add Livy Operator with deferrable mode (apache#29047)
This PR donates a deferrable mode for LivyOperator, which wraps the Apache Livy batch REST API and allows a Spark application to be submitted to the underlying cluster asynchronously. The code was originally developed in the astronomer-providers repo and is contributed here to Apache Airflow.


Co-authored-by: Wei Lee <[email protected]>
sunank200 and Lee-W authored Feb 22, 2023
1 parent 8430d60 commit 47ebe99
Showing 11 changed files with 1,402 additions and 11 deletions.
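
For context, a minimal sketch of how the new deferrable mode might be used in a DAG; the dag_id, file path, and schedule are hypothetical, while livy_conn_id, polling_interval, and deferrable come from the operator signature in this PR:

from datetime import datetime

from airflow import DAG
from airflow.providers.apache.livy.operators.livy import LivyOperator

with DAG(dag_id="example_livy_deferrable", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    # Submit the batch, then release the worker slot; a triggerer polls Livy every 30 seconds.
    submit_spark_pi = LivyOperator(
        task_id="submit_spark_pi",
        livy_conn_id="livy_default",
        file="local:///opt/spark/examples/src/main/python/pi.py",
        polling_interval=30,
        deferrable=True,
    )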
390 changes: 389 additions & 1 deletion airflow/providers/apache/livy/hooks/livy.py

Large diffs are not rendered by default.

52 changes: 45 additions & 7 deletions airflow/providers/apache/livy/operators/livy.py
@@ -23,6 +23,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook
from airflow.providers.apache.livy.triggers.livy import LivyTrigger

if TYPE_CHECKING:
    from airflow.utils.context import Context
@@ -56,6 +57,7 @@ class LivyOperator(BaseOperator):
        depends on the option that's being modified.
    :param extra_headers: A dictionary of headers passed to the HTTP request to livy.
    :param retry_args: Arguments which define the retry behaviour.
        See Tenacity documentation at https://github.com/jd/tenacity
    :param deferrable: Run operator in the deferrable mode
    """

@@ -87,6 +89,7 @@ def __init__(
        extra_options: dict[str, Any] | None = None,
        extra_headers: dict[str, Any] | None = None,
        retry_args: dict[str, Any] | None = None,
        deferrable: bool = False,
        **kwargs: Any,
    ) -> None:

@@ -120,6 +123,7 @@ def __init__(
        self._livy_hook: LivyHook | None = None
        self._batch_id: int | str
        self.retry_args = retry_args
        self.deferrable = deferrable

    def get_hook(self) -> LivyHook:
        """
@@ -138,13 +142,27 @@ def get_hook(self) -> LivyHook:

    def execute(self, context: Context) -> Any:
        self._batch_id = self.get_hook().post_batch(**self.spark_params)

-        if self._polling_interval > 0:
-            self.poll_for_termination(self._batch_id)
-
-        context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
-
-        return self._batch_id
+        self.log.info("Generated batch-id is %s", self._batch_id)
+
+        # Wait for the job to complete
+        if not self.deferrable:
+            if self._polling_interval > 0:
+                self.poll_for_termination(self._batch_id)
+            context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
+            return self._batch_id
+
+        self.defer(
+            timeout=self.execution_timeout,
+            trigger=LivyTrigger(
+                batch_id=self._batch_id,
+                spark_params=self.spark_params,
+                livy_conn_id=self._livy_conn_id,
+                polling_interval=self._polling_interval,
+                extra_options=self._extra_options,
+                extra_headers=self._extra_headers,
+            ),
+            method_name="execute_complete",
+        )

    def poll_for_termination(self, batch_id: int | str) -> None:
        """
@@ -170,3 +188,23 @@ def kill(self) -> None:
        """Delete the current batch session."""
        if self._batch_id is not None:
            self.get_hook().delete_batch(self._batch_id)

    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
        """
        Callback invoked when the trigger fires; returns immediately.
        Relies on the trigger to raise an exception, otherwise it assumes execution was
        successful.
        """
        # Dump the logs from Livy to the worker through the triggerer.
        if event.get("log_lines", None) is not None:
            for log_line in event["log_lines"]:
                self.log.info(log_line)

        if event["status"] == "error":
            raise AirflowException(event["response"])
        self.log.info(
            "%s completed with response %s",
            self.task_id,
            event["response"],
        )
        return event["batch_id"]
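
For illustration, a sketch of a success payload in the shape built by LivyTrigger.poll_for_termination (values hypothetical):

event = {
    "status": "success",
    "batch_id": 42,
    "response": "Batch 42 succeeded",
    "log_lines": ["stdout:", "Pi is roughly 3.141580"],
}
# execute_complete(context, event) would log each log line and return 42; with
# "status": "error" it would raise AirflowException(event["response"]) instead.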
2 changes: 2 additions & 0 deletions airflow/providers/apache/livy/provider.yaml
@@ -38,6 +38,8 @@ versions:
dependencies:
  - apache-airflow>=2.3.0
  - apache-airflow-providers-http
  - aiohttp
  - asgiref

integrations:
  - integration-name: Apache Livy
16 changes: 16 additions & 0 deletions airflow/providers/apache/livy/triggers/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
146 changes: 146 additions & 0 deletions airflow/providers/apache/livy/triggers/livy.py
@@ -0,0 +1,146 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""This module contains the Apache Livy Trigger."""
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator

from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class LivyTrigger(BaseTrigger):
    """
    Check for the state of a previously submitted job with batch_id.

    :param batch_id: Batch job id
    :param spark_params: Spark parameters; for example,
        spark_params = {"file": "test/pi.py", "class_name": "org.apache.spark.examples.SparkPi",
        "args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"], "jars": "command-runner.jar",
        "driver_cores": 1, "executor_cores": 4, "num_executors": 1}
    :param livy_conn_id: reference to a pre-defined Livy Connection.
    :param polling_interval: time in seconds between polls for job completion. If polling_interval=0,
        return the batch_id immediately; if polling_interval > 0, poll the Livy job for termination
        at the defined interval.
    :param extra_options: A dictionary of options, where key is string and value
        depends on the option that's being modified.
    :param extra_headers: A dictionary of headers passed to the HTTP request to livy.
    :param livy_hook_async: LivyAsyncHook object
    """

    def __init__(
        self,
        batch_id: int | str,
        spark_params: dict[Any, Any],
        livy_conn_id: str = "livy_default",
        polling_interval: int = 0,
        extra_options: dict[str, Any] | None = None,
        extra_headers: dict[str, Any] | None = None,
        livy_hook_async: LivyAsyncHook | None = None,
    ):
        super().__init__()
        self._batch_id = batch_id
        self.spark_params = spark_params
        self._livy_conn_id = livy_conn_id
        self._polling_interval = polling_interval
        self._extra_options = extra_options
        self._extra_headers = extra_headers
        self._livy_hook_async = livy_hook_async

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes LivyTrigger arguments and classpath."""
        return (
            "airflow.providers.apache.livy.triggers.livy.LivyTrigger",
            {
                "batch_id": self._batch_id,
                "spark_params": self.spark_params,
                "livy_conn_id": self._livy_conn_id,
                "polling_interval": self._polling_interval,
                "extra_options": self._extra_options,
                "extra_headers": self._extra_headers,
                "livy_hook_async": self._livy_hook_async,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Run the trigger. If _polling_interval > 0, poll Livy for batch
        termination asynchronously; otherwise yield a success event immediately.
        """
        try:
            if self._polling_interval > 0:
                response = await self.poll_for_termination(self._batch_id)
                yield TriggerEvent(response)
            yield TriggerEvent(
                {
                    "status": "success",
                    "batch_id": self._batch_id,
                    "response": f"Batch {self._batch_id} succeeded",
                    "log_lines": None,
                }
            )
        except Exception as exc:
            yield TriggerEvent(
                {
                    "status": "error",
                    "batch_id": self._batch_id,
                    "response": f"Batch {self._batch_id} did not succeed with {str(exc)}",
                    "log_lines": None,
                }
            )

    async def poll_for_termination(self, batch_id: int | str) -> dict[str, Any]:
        """
        Poll Livy for batch termination asynchronously.

        :param batch_id: id of the batch session to monitor.
        """
        hook = self._get_async_hook()
        state = await hook.get_batch_state(batch_id)
        self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
        while state["batch_state"] not in hook.TERMINAL_STATES:
            self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
            self.log.info("Sleeping for %s seconds", self._polling_interval)
            await asyncio.sleep(self._polling_interval)
            state = await hook.get_batch_state(batch_id)
        self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value)
        log_lines = await hook.dump_batch_logs(batch_id)
        if state["batch_state"] != BatchState.SUCCESS:
            return {
                "status": "error",
                "batch_id": batch_id,
                "response": f"Batch {batch_id} did not succeed",
                "log_lines": log_lines,
            }
        return {
            "status": "success",
            "batch_id": batch_id,
            "response": f"Batch {batch_id} succeeded",
            "log_lines": log_lines,
        }

    def _get_async_hook(self) -> LivyAsyncHook:
        if self._livy_hook_async is None or not isinstance(self._livy_hook_async, LivyAsyncHook):
            self._livy_hook_async = LivyAsyncHook(
                livy_conn_id=self._livy_conn_id,
                extra_headers=self._extra_headers,
                extra_options=self._extra_options,
            )
        return self._livy_hook_async
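
As a rough illustration of the trigger's contract, the sketch below drives run() directly with asyncio, assuming a reachable livy_default connection and an already-submitted batch (the id 42 is hypothetical); in production the triggerer consumes these events instead:

import asyncio

from airflow.providers.apache.livy.triggers.livy import LivyTrigger


async def main() -> None:
    trigger = LivyTrigger(batch_id=42, spark_params={}, polling_interval=10)
    async for event in trigger.run():
        # The first TriggerEvent carries the status/batch_id/response/log_lines payload.
        print(event.payload)
        break


asyncio.run(main())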
4 changes: 3 additions & 1 deletion generated/provider_dependencies.json
@@ -134,8 +134,10 @@
  },
  "apache.livy": {
    "deps": [
+     "aiohttp",
      "apache-airflow-providers-http",
-     "apache-airflow>=2.3.0"
+     "apache-airflow>=2.3.0",
+     "asgiref"
    ],
    "cross-providers-deps": [
      "http"
37 changes: 37 additions & 0 deletions tests/providers/apache/livy/compat.py
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

__all__ = ["async_mock", "AsyncMock"]

import sys

if sys.version_info < (3, 8):
    # For compatibility with Python 3.7
    from asynctest import mock as async_mock

    # ``asynctest.mock.CoroutineMock``, which provides the compatibility layer, does not work well
    # with autospec=True; as a result, "TypeError: object MagicMock can't be used in 'await'
    # expression" could be raised. The best solution in this case is to provide the actual
    # awaitable object as spec:
    # >>> from tests.providers.apache.livy.compat import AsyncMock
    # >>> from foo.bar import SpamEgg
    # >>> mock_something = AsyncMock(SpamEgg)
    from asynctest.mock import CoroutineMock as AsyncMock
else:
    from unittest import mock as async_mock
    from unittest.mock import AsyncMock
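
For instance, a test sketch (assuming the pytest-asyncio plugin) that uses the shim to stub the async hook methods and assert a success event from the trigger:

import pytest

from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
from airflow.providers.apache.livy.triggers.livy import LivyTrigger
from tests.providers.apache.livy.compat import async_mock


@pytest.mark.asyncio
async def test_livy_trigger_success():
    trigger = LivyTrigger(batch_id=1, spark_params={}, polling_interval=1)
    # On Python 3.8+, patch.object detects async defs and substitutes AsyncMock;
    # on 3.7 the asynctest shim provides the equivalent CoroutineMock behaviour.
    with async_mock.patch.object(
        LivyAsyncHook, "get_batch_state", return_value={"batch_state": BatchState.SUCCESS}
    ), async_mock.patch.object(LivyAsyncHook, "dump_batch_logs", return_value=["job done"]):
        event = await trigger.run().__anext__()
    assert event.payload["status"] == "success"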