Skip to content

Commit

Permalink
Capture warning during setup and collect tests cases (apache#39250)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Apr 25, 2024
1 parent 4920ab2 commit 5eaf173
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 49 deletions.
9 changes: 6 additions & 3 deletions contributing-docs/testing/unit_tests.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1152,10 +1152,13 @@ or by setting the environment variable ``CAPTURE_WARNINGS_OUTPUT``.
root@3f98e75b1ebe:/opt/airflow# pytest tests/core/ --warning-output-path=/foo/bar/spam.egg
...
========================= Warning summary. Total: 34, Unique: 16 ==========================
========================= Warning summary. Total: 28, Unique: 12 ==========================
airflow: total 11, unique 1
other: total 12, unique 4
tests: total 11, unique 11
runtest: total 11, unique 1
other: total 7, unique 1
runtest: total 7, unique 1
tests: total 10, unique 10
runtest: total 10, unique 10
Warnings saved into /foo/bar/spam.egg file.
================================= short test summary info =================================
Expand Down
49 changes: 22 additions & 27 deletions scripts/ci/testing/summarize_captured_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)


REQUIRED_FIELDS = ("category", "message", "node_id", "filename", "lineno", "group", "count")
REQUIRED_FIELDS = ("category", "message", "node_id", "filename", "lineno", "group", "count", "when")
CONSOLE_SIZE = shutil.get_terminal_size((80, 20)).columns
# Use as prefix/suffix in report output
IMPORTANT_WARNING_SIGN = {
Expand Down Expand Up @@ -71,8 +71,8 @@ def warnings_filename(suffix: str) -> str:


@functools.lru_cache(maxsize=None)
def _unique_key(*args: str) -> str:
return str(uuid5(NAMESPACE_OID, "-".join(args)))
def _unique_key(*args: str | None) -> str:
return str(uuid5(NAMESPACE_OID, "-".join(map(str, args))))


def sorted_groupby(it, grouping_key: Callable):
Expand All @@ -95,9 +95,10 @@ def count_groups(
class CapturedWarnings:
category: str
message: str
node_id: str
filename: str
lineno: int
when: str
node_id: str | None

@property
def unique_warning(self) -> str:
Expand Down Expand Up @@ -176,8 +177,8 @@ def merge_files(files: Iterator[tuple[Path, str]], output_directory: Path) -> Pa
return output_file


def group_report_warnings(group, group_records, output_directory: Path) -> None:
output_filepath = output_directory / warnings_filename(f"group-{group}")
def group_report_warnings(group, when: str, group_records, output_directory: Path) -> None:
output_filepath = output_directory / warnings_filename(f"{group}-{when}")

group_warnings: dict[str, CapturedWarnings] = {}
unique_group_warnings: dict[str, CapturedWarnings] = {}
Expand All @@ -188,27 +189,21 @@ def group_report_warnings(group, group_records, output_directory: Path) -> None:
if cw.unique_warning not in unique_group_warnings:
unique_group_warnings[cw.unique_warning] = cw

print(f" Group {group!r} ".center(CONSOLE_SIZE, "="))
print(f" Group {group!r} on {when!r} ".center(CONSOLE_SIZE, "="))
with output_filepath.open(mode="w") as fp:
for cw in group_warnings.values():
fp.write(f"{cw.output()}\n")
print(f"Saved into file: {output_filepath.as_posix()}\n")

print(f"Unique warnings within the test cases: {len(group_warnings):,}\n")
print("Top 10 Tests Cases:")
it = count_groups(
group_warnings.values(),
grouping_key=lambda cw: (
cw.category,
cw.node_id,
),
top=10,
)
for (category, node_id), count in it:
if suffix := IMPORTANT_WARNING_SIGN.get(category, ""):
suffix = f" ({suffix})"
print(f" {category} {node_id} - {count:,}{suffix}")
print()
if when == "runtest": # Node id exists only during runtest
print(f"Unique warnings within the test cases: {len(group_warnings):,}\n")
print("Top 10 Tests Cases:")
it = count_groups(group_warnings.values(), grouping_key=lambda cw: (cw.category, cw.node_id), top=10)
for (category, node_id), count in it:
if suffix := IMPORTANT_WARNING_SIGN.get(category, ""):
suffix = f" ({suffix})"
print(f" {category} {node_id} - {count:,}{suffix}")
print()

print(f"Unique warnings: {len(unique_group_warnings):,}\n")
print("Warnings grouped by category:")
Expand All @@ -232,8 +227,6 @@ def group_report_warnings(group, group_records, output_directory: Path) -> None:
if always:
print(f" Always reported warnings {len(always):,}".center(CONSOLE_SIZE, "-"))
for cw in always:
if prefix := IMPORTANT_WARNING_SIGN.get(cw.category, ""):
prefix = f" ({prefix})"
print(f"{cw.filename}:{cw.lineno}")
print(f" {cw.category} - {cw.message}")
print()
Expand All @@ -243,8 +236,10 @@ def split_by_groups(output_file: Path, output_directory: Path) -> None:
records: list[dict] = []
with output_file.open() as fp:
records.extend(map(json.loads, fp))
for group, group_records in sorted_groupby(records, grouping_key=lambda record: record["group"]):
group_report_warnings(group, group_records, output_directory)
for (group, when), group_records in sorted_groupby(
records, grouping_key=lambda record: (record["group"], record["when"])
):
group_report_warnings(group, when, group_records, output_directory)


def main(_input: str, _output: str | None, pattern: str | None) -> int | str:
Expand All @@ -260,7 +255,7 @@ def main(_input: str, _output: str | None, pattern: str | None) -> int | str:
print(f" Process file {input_path} ".center(CONSOLE_SIZE, "="))
if not input_path.is_file():
return f"{input_path} is not a file."
files = resolve_file(input_path, cwd)
files = resolve_file(input_path, cwd if not input_path.is_absolute() else None)
else:
if not input_path.is_dir():
return f"{input_path} is not a file."
Expand Down
82 changes: 63 additions & 19 deletions tests/_internals/capture_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@
import site
import sys
import warnings
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Callable
from typing import Callable, Generator

import pytest
from typing_extensions import Literal

WhenTypeDef = Literal["config", "collect", "runtest"]
TESTS_DIR = Path(__file__).parents[1].resolve()


Expand All @@ -53,26 +56,47 @@ def _resolve_warning_filepath(path: str, rootpath: str):
class CapturedWarning:
category: str
message: str
node_id: str
filename: str
lineno: int
when: WhenTypeDef
node_id: str | None = None

@classmethod
def from_record(
cls, warning_message: warnings.WarningMessage, node_id: str, root_path: Path
cls, warning_message: warnings.WarningMessage, root_path: Path, node_id: str | None, when: WhenTypeDef
) -> CapturedWarning:
category = warning_message.category.__name__
if (category_module := warning_message.category.__module__) != "builtins":
category = f"{category_module}.{category}"
node_id, *_ = node_id.partition("[")
if node_id:
# Remove parametrized part from the test node
node_id, *_ = node_id.partition("[")
return cls(
category=category,
message=str(warning_message.message),
node_id=node_id,
when=when,
filename=_resolve_warning_filepath(warning_message.filename, os.fspath(root_path)),
lineno=warning_message.lineno,
)

@classmethod
@contextmanager
def capture_warnings(
    cls, when: WhenTypeDef, root_path: Path, node_id: str | None = None
) -> Generator[list[CapturedWarning], None, None]:
    """
    Record warnings raised inside the ``with`` block and convert them with ``from_record``.

    The yielded list is empty for the whole duration of the ``with`` block; it is
    filled only when the block exits (normally or through an exception), so callers
    must read it after leaving the context.

    :param when: stage label stored on every captured warning.
    :param root_path: passed through to ``from_record`` to resolve the warning's file path.
    :param node_id: test node id passed through to ``from_record``; ``None`` when not given.
    :return: context manager yielding the list that receives the captured warnings on exit.
    """
    captured_records: list[CapturedWarning] = []
    # Bound up-front so the ``finally`` clause below can never hit an unbound name
    # (and mask the real error) if entering the capture context itself fails.
    records: list[warnings.WarningMessage] = []
    try:
        with warnings.catch_warnings(record=True) as records:
            if not sys.warnoptions:
                # No ``-W`` options were given to the interpreter: make sure deprecations
                # are reported. ``append=True`` keeps already installed filters ahead of these.
                warnings.filterwarnings("always", category=DeprecationWarning, append=True)
                warnings.filterwarnings("always", category=PendingDeprecationWarning, append=True)
            yield captured_records
    finally:
        # Runs even when the wrapped code raised, so warnings emitted before a failure are kept.
        captured_records.extend(
            cls.from_record(rec, root_path=root_path, node_id=node_id, when=when) for rec in records
        )

@property
def uniq_key(self):
    """
    Return the tuple used to decide whether two captured warnings are the same warning.

    The key is ``(category, message, filename, lineno)``; ``node_id`` and ``when`` are
    not part of it.
    """
    # ``lineno`` used to be listed twice while ``filename`` was missing, so the same
    # message raised on an equal line number of two different files counted as one.
    return self.category, self.message, self.filename, self.lineno
Expand Down Expand Up @@ -123,25 +147,37 @@ def __init__(self, config: pytest.Config, output_path: str | None = None):
self.is_worker_node = hasattr(config, "workerinput")
self.captured_warnings: dict[CapturedWarning, int] = {}

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item: pytest.Item):
with warnings.catch_warnings(record=True) as records:
if not sys.warnoptions:
warnings.filterwarnings("always", category=DeprecationWarning, append=True)
warnings.filterwarnings("always", category=PendingDeprecationWarning, append=True)
def add_captured_warnings(self, cap_warning: list[CapturedWarning]) -> None:
    """
    Add captured warnings to the ``self.captured_warnings`` occurrence counter.

    :param cap_warning: captured warnings to count; each occurrence in the list
        increments the counter of the equal key by one. An empty list is a no-op.
    """
    for cw in cap_warning:
        self.captured_warnings[cw] = self.captured_warnings.get(cw, 0) + 1

@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_collection(self, session: pytest.Session):
    """
    Record warnings raised while the wrapped ``pytest_collection`` hook runs.

    Captured warnings are labelled ``"collect"`` and carry no test node id.
    """
    with CapturedWarning.capture_warnings("collect", self.root_path, None) as records:
        # Hook wrapper: the wrapped hook implementations run while the capture is active.
        yield
    # ``records`` is filled when the capture context exits, so it is counted only here.
    self.add_captured_warnings(records)

for record in records:
cap_warning = CapturedWarning.from_record(record, item.nodeid, root_path=self.root_path)
if cap_warning not in self.captured_warnings:
self.captured_warnings[cap_warning] = 1
else:
self.captured_warnings[cap_warning] += 1
@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_load_initial_conftests(self, early_config: pytest.Config):
    """
    Record warnings raised while the wrapped ``pytest_load_initial_conftests`` hook runs.

    Captured warnings are labelled ``"collect"`` and carry no test node id.
    """
    # NOTE(review): labelled "collect" although a "config" label also exists and this hook
    # loads conftest files — confirm that "collect" is the intended stage here.
    with CapturedWarning.capture_warnings("collect", self.root_path, None) as records:
        # Hook wrapper: the wrapped hook implementations run while the capture is active.
        yield
    # ``records`` is filled when the capture context exits, so it is counted only here.
    self.add_captured_warnings(records)

@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_runtest_protocol(self, item: pytest.Item):
    """
    Record warnings raised while the wrapped ``pytest_runtest_protocol`` hook runs for ``item``.

    Captured warnings are labelled ``"runtest"`` and carry the item's node id.
    """
    with CapturedWarning.capture_warnings("runtest", self.root_path, item.nodeid) as records:
        # Hook wrapper: the wrapped hook implementations run while the capture is active.
        yield
    # ``records`` is filled when the capture context exits, so it is counted only here.
    self.add_captured_warnings(records)

@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_sessionfinish(self, session: pytest.Session, exitstatus: int):
"""Save warning captures in the session finish on xdist worker node"""
yield
with CapturedWarning.capture_warnings("config", self.root_path, None) as records:
yield
self.add_captured_warnings(records)
if self.is_worker_node and self.captured_warnings and hasattr(self.config, "workeroutput"):
self.config.workeroutput[self.node_key] = tuple(
[(cw.dumps(), count) for cw, count in self.captured_warnings.items()]
Expand Down Expand Up @@ -169,9 +205,12 @@ def sorted_groupby(it, grouping_key: Callable):
for group, grouped_data in itertools.groupby(sorted(it, key=grouping_key), key=grouping_key):
yield group, list(grouped_data)

@pytest.hookimpl(hookwrapper=True)
@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_terminal_summary(self, terminalreporter, exitstatus: int, config: pytest.Config):
yield
with CapturedWarning.capture_warnings("collect", self.root_path, None) as records:
yield
self.add_captured_warnings(records)

if self.is_worker_node: # No need to print/write file on worker node
return

Expand Down Expand Up @@ -203,6 +242,11 @@ def pytest_terminal_summary(self, terminalreporter, exitstatus: int, config: pyt
f": total {sum(item[1] for item in grouped_data):,}, "
f"unique {len({item[0].uniq_key for item in grouped_data}):,}\n"
)
for when, when_data in self.sorted_groupby(grouped_data, lambda x: x[0].when):
terminalreporter.write(
f" {when}: total {sum(item[1] for item in when_data):,}, "
f"unique {len({item[0].uniq_key for item in when_data}):,}\n"
)

with self.warning_output_path.open("w") as fp:
for cw, count in self.captured_warnings.items():
Expand Down

0 comments on commit 5eaf173

Please sign in to comment.