Skip to content

Commit

Permalink
Handle json encoding of V1Pod in task callback (apache#27609)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstandish authored Nov 16, 2022
1 parent dc03b90 commit 92389cf
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 14 deletions.
14 changes: 7 additions & 7 deletions airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,17 @@ def __init__(
self.is_failure_callback = is_failure_callback

def to_json(self) -> str:
dict_obj = self.__dict__.copy()
dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict()
return json.dumps(dict_obj)
from airflow.serialization.serialized_objects import BaseSerialization

val = BaseSerialization.serialize(self.__dict__, strict=True)
return json.dumps(val)

@classmethod
def from_json(cls, json_str: str):
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.serialization.serialized_objects import BaseSerialization

kwargs = json.loads(json_str)
simple_ti = SimpleTaskInstance.from_dict(obj_dict=kwargs.pop("simple_task_instance"))
return cls(simple_task_instance=simple_ti, **kwargs)
val = json.loads(json_str)
return cls(**BaseSerialization.deserialize(val))


class DagCallbackRequest(CallbackRequest):
Expand Down
2 changes: 1 addition & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __str__(self) -> str:


class SerializationError(AirflowException):
"""A problem occurred when trying to serialize a DAG."""
"""A problem occurred when trying to serialize something."""


class ParamValidationError(AirflowException):
Expand Down
10 changes: 10 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2574,6 +2574,11 @@ def __eq__(self, other):
return NotImplemented

def as_dict(self):
warnings.warn(
"This method is deprecated. Use BaseSerialization.serialize.",
RemovedInAirflow3Warning,
stacklevel=2,
)
new_dict = dict(self.__dict__)
for key in new_dict:
if key in ["start_date", "end_date"]:
Expand Down Expand Up @@ -2604,6 +2609,11 @@ def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:

@classmethod
def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
warnings.warn(
"This method is deprecated. Use BaseSerialization.deserialize.",
RemovedInAirflow3Warning,
stacklevel=2,
)
ti_key = TaskInstanceKey(*obj_dict.pop("key"))
start_date = None
end_date = None
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ class DagAttributeTypes(str, Enum):
PARAM = "param"
XCOM_REF = "xcomref"
DATASET = "dataset"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
23 changes: 17 additions & 6 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
Expand Down Expand Up @@ -381,7 +382,9 @@ def serialize_to_json(
return serialized_object

@classmethod
def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for recursive types in mypy
def serialize(
cls, var: Any, *, strict: bool = False
) -> Any: # Unfortunately there is no support for recursive types in mypy
"""Helper function of depth first search for serialization.
The serialization protocol is:
Expand All @@ -400,9 +403,11 @@ def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for re
return var.value
return var
elif isinstance(var, dict):
return cls._encode({str(k): cls.serialize(v) for k, v in var.items()}, type_=DAT.DICT)
return cls._encode(
{str(k): cls.serialize(v, strict=strict) for k, v in var.items()}, type_=DAT.DICT
)
elif isinstance(var, list):
return [cls.serialize(v) for v in var]
return [cls.serialize(v, strict=strict) for v in var]
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
Expand All @@ -427,12 +432,12 @@ def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for re
elif isinstance(var, set):
# FIXME: casts set to list in customized serialization in future.
try:
return cls._encode(sorted(cls.serialize(v) for v in var), type_=DAT.SET)
return cls._encode(sorted(cls.serialize(v, strict=strict) for v in var), type_=DAT.SET)
except TypeError:
return cls._encode([cls.serialize(v) for v in var], type_=DAT.SET)
return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.SET)
elif isinstance(var, tuple):
# FIXME: casts tuple to list in customized serialization in future.
return cls._encode([cls.serialize(v) for v in var], type_=DAT.TUPLE)
return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.TUPLE)
elif isinstance(var, TaskGroup):
return TaskGroupSerialization.serialize_task_group(var)
elif isinstance(var, Param):
Expand All @@ -441,8 +446,12 @@ def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for re
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(cls.serialize(var.__dict__, strict=strict), type_=DAT.SIMPLE_TASK_INSTANCE)
else:
log.debug("Cast type %s to str in serialization.", type(var))
if strict:
raise SerializationError("Encountered unexpected type")
return str(var)

@classmethod
Expand Down Expand Up @@ -491,6 +500,8 @@ def deserialize(cls, encoded_var: Any) -> Any:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.DATASET:
return Dataset(**var)
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")

Expand Down
5 changes: 5 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from pathlib import Path

REPO_ROOT = Path(__file__).parent.parent
36 changes: 36 additions & 0 deletions tests/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,39 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
json_str = input.to_json()
result = TaskCallbackRequest.from_json(json_str)
assert input == result

def test_simple_ti_roundtrip_exec_config_pod(self):
"""A callback request including a TI with an exec config with a V1Pod should safely roundtrip."""
from kubernetes.client import models as k8s

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.models import TaskInstance
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.operators.bash import BashOperator

test_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="hello", namespace="ns"))
op = BashOperator(task_id="hi", executor_config={"pod_override": test_pod}, bash_command="hi")
ti = TaskInstance(task=op)
s = SimpleTaskInstance.from_ti(ti)
data = TaskCallbackRequest("hi", s).to_json()
actual = TaskCallbackRequest.from_json(data).simple_task_instance.executor_config["pod_override"]
assert actual == test_pod

def test_simple_ti_roundtrip_dates(self):
"""A callback request including a TI with an exec config with a V1Pod should safely roundtrip."""
from unittest.mock import MagicMock

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.models import TaskInstance
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.operators.bash import BashOperator

op = BashOperator(task_id="hi", bash_command="hi")
ti = TaskInstance(task=op)
ti.set_state("SUCCESS", session=MagicMock())
start_date = ti.start_date
end_date = ti.end_date
s = SimpleTaskInstance.from_ti(ti)
data = TaskCallbackRequest("hi", s).to_json()
assert TaskCallbackRequest.from_json(data).simple_task_instance.start_date == start_date
assert TaskCallbackRequest.from_json(data).simple_task_instance.end_date == end_date
78 changes: 78 additions & 0 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pytest

from airflow.exceptions import SerializationError
from tests import REPO_ROOT


def test_recursive_serialize_calls_must_forward_kwargs():
"""Any time we recurse cls.serialize, we must forward all kwargs."""
import ast

valid_recursive_call_count = 0
file = REPO_ROOT / "airflow/serialization/serialized_objects.py"
content = file.read_text()
tree = ast.parse(content)

class_def = None
for stmt in ast.walk(tree):
if not isinstance(stmt, ast.ClassDef):
continue
if stmt.name == "BaseSerialization":
class_def = stmt

method_def = None
for elem in ast.walk(class_def):
if isinstance(elem, ast.FunctionDef):
if elem.name == "serialize":
method_def = elem
break
kwonly_args = [x.arg for x in method_def.args.kwonlyargs]

for elem in ast.walk(method_def):
if isinstance(elem, ast.Call):
if getattr(elem.func, "attr", "") == "serialize":
kwargs = {y.arg: y.value for y in elem.keywords}
for name in kwonly_args:
if name not in kwargs or getattr(kwargs[name], "id", "") != name:
ref = f"{file}:{elem.lineno}"
message = (
f"Error at {ref}; recursive calls to `cls.serialize` "
f"must forward the `{name}` argument"
)
raise Exception(message)
valid_recursive_call_count += 1
print(f"validated calls: {valid_recursive_call_count}")
assert valid_recursive_call_count > 0


def test_strict_mode():
"""If strict=True, serialization should fail when object is not JSON serializable."""

class Test:
a = 1

from airflow.serialization.serialized_objects import BaseSerialization

obj = [[[Test()]]] # nested to verify recursive behavior
BaseSerialization.serialize(obj) # does not raise
with pytest.raises(SerializationError, match="Encountered unexpected type"):
BaseSerialization.serialize(obj, strict=True) # now raises

0 comments on commit 92389cf

Please sign in to comment.