forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytest_checker.py
126 lines (111 loc) · 4.08 KB
/
pytest_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
import re
import sys
from pathlib import Path
def check_file(file_contents: str) -> bool:
    """Return True if *file_contents* contains an unindented ``__main__`` guard.

    ``re.M`` makes ``^`` match at the start of every line, so the guard must
    begin in column 0 (a top-level statement). Both ``"__main__"`` and
    ``'__main__'`` are accepted; the backreference ``\\1`` requires the
    closing quote to match the opening one.
    """
    return bool(
        re.search(r"^if __name__ == ([\"'])__main__\1:", file_contents, re.M)
    )
def parse_json(data: str) -> dict:
    """Decode the JSON document in *data* and return the resulting object."""
    decoded = json.loads(data)
    return decoded
def treat_path(path: str) -> Path:
    """Convert a Bazel label into a workspace-relative filesystem path.

    The first two characters (the ``//`` root marker of an absolute label)
    are dropped and every ``:`` becomes ``/``, e.g.
    ``//python/ray/tests:test_tracing.py`` -> ``python/ray/tests/test_tracing.py``.
    """
    label_body = path[2:]
    segments = label_body.split(":")
    return Path("/".join(segments))
def _as_list(node) -> list:
    """Return *node* as a list (``None`` -> ``[]``, a single object -> ``[obj]``).

    In the JSON form of the XML query output, an element that occurs once is a
    plain object while a repeated element is a list. Without this, a lone
    ``rule``/``list``/``string``/``label`` would be iterated key-by-key.
    """
    if node is None:
        return []
    if isinstance(node, list):
        return node
    return [node]


def _find_src_for_name(labels: list, test_name: str) -> str:
    """Return the ``srcs`` label value that holds the test named *test_name*.

    Bazel's default ``main`` for a py_test is ``<name>.py``, so an exact file
    match is preferred. The looser "name is a substring of the label" match is
    kept as a fallback for backward compatibility.

    Raises:
        ValueError: If no label matches at all.
    """
    expected = f"{test_name}.py"
    values = [label["@value"] for label in labels]
    for value in values:
        # Match the target part exactly: "//pkg:<expected>" or "//pkg:sub/<expected>".
        if value.endswith((":" + expected, "/" + expected)):
            return value
    for value in values:
        if test_name in value:
            return value
    raise ValueError(
        f"Could not find a srcs entry for test name {test_name!r} among {values}"
    )


def get_paths_from_parsed_data(parsed_data: dict) -> list:
    """Map every py_test rule in *parsed_data* to the file holding its tests.

    Args:
        parsed_data: JSON form of a ``bazel query --output xml`` result, as
            returned by ``parse_json``.

    Returns:
        A list of ``(rule_label, path)`` tuples, where ``rule_label`` is the
        rule's ``@name`` and ``path`` is the result of ``treat_path``.

    Raises:
        KeyError: If a rule has no ``main`` label and no ``srcs`` list.
        ValueError: If a rule has several ``srcs`` and none matches its name.
    """
    # Example JSON input:
    # "rule": [
    #   {
    #     "@class": "py_test",
    #     "@location": "/home/ubuntu/ray/python/ray/tests/BUILD:345:8",
    #     "@name": "//python/ray/tests:test_tracing",
    #     "string": [
    #       {
    #         "@name": "name",
    #         "@value": "test_tracing"
    #       },
    #     ],
    #     "list": [
    #       {
    #         "@name": "srcs",
    #         "label": [
    #           {
    #             "@value": "//python/ray/tests:aws/conftest.py"
    #           },
    #           {
    #             "@value": "//python/ray/tests:conftest.py"
    #           },
    #           {
    #             "@value": "//python/ray/tests:test_tracing.py"
    #           }
    #         ]
    #       }
    #     ],
    #     ... other fields ...
    #     "label": {
    #       "@name": "main",
    #       "@value": "//python/ray/tests:test_runtime_env_working_dir_remote_uri.py"
    #     },
    #     ... other fields ...
    #   }
    # ]
    #
    # We want to get the location of the actual test file.
    # This can be, in order of priority:
    # 1. Specified as the "main" label
    # 2. Specified as the ONLY "srcs" label
    # 3. Specified as the "srcs" label matching the "name" of the test
    # https://docs.bazel.build/versions/main/be/python.html#py_test
    paths = []
    for rule in _as_list(parsed_data["query"]["rule"]):
        name = rule["@name"]

        # 1. An explicit "main" label. A rule may carry several label
        # attributes, in which case "label" is a list.
        main_label = next(
            (
                label
                for label in _as_list(rule.get("label"))
                if label.get("@name") == "main"
            ),
            None,
        )
        if main_label is not None:
            paths.append((name, treat_path(main_label["@value"])))
            continue

        list_args = {e["@name"]: e for e in _as_list(rule.get("list"))}
        src_labels = _as_list(list_args["srcs"]["label"])

        # 2. The only entry in "srcs".
        if len(src_labels) == 1:
            paths.append((name, treat_path(src_labels[0]["@value"])))
            continue

        # 3. The "srcs" entry matching the rule's short "name" attribute.
        string_name = next(
            s["@value"] for s in _as_list(rule.get("string")) if s["@name"] == "name"
        )
        paths.append((name, treat_path(_find_src_for_name(src_labels, string_name))))
    return paths
def main(data: str):
    """Check every py_test file referenced by *data* for a ``__main__`` guard.

    Args:
        data: JSON text of a ``bazel query --output xml`` result (converted
            to JSON before being piped in).

    Raises:
        RuntimeError: If any test file is missing or lacks the snippet; the
            message lists every offending path, one per line.
    """
    print("Checking files for the pytest snippet...")
    parsed_data = parse_json(data)
    paths = get_paths_from_parsed_data(parsed_data)
    # Human-readable descriptions of failing files, one string per entry so
    # the final report is formatted uniformly.
    bad_paths = []
    for name, path in paths:
        # Special case for myst doc checker
        if "test_myst_doc" in str(path):
            continue
        print(f"Checking test '{name}' | file '{path}'...")
        try:
            # Python source files are UTF-8 by default (PEP 3120); don't rely
            # on the locale's preferred encoding.
            with open(path, "r", encoding="utf-8") as f:
                contents = f.read()
        except FileNotFoundError:
            print(f"File '{path}' is missing.")
            bad_paths.append(f"{path} (path is missing!)")
            continue
        if not check_file(contents):
            print(f"File '{path}' is missing the pytest snippet.")
            bad_paths.append(str(path))
    if bad_paths:
        formatted_bad_paths = "\n".join(bad_paths)
        raise RuntimeError(
            'Found py_test files without `if __name__ == "__main__":` snippet:'
            f"\n{formatted_bad_paths}\n"
            "If this is intentional, please add a `no_main` tag to bazel BUILD "
            "entry for those files."
        )
if __name__ == "__main__":
    # Reads JSON on stdin: the XML output of a `bazel query`, converted to
    # JSON. Example invocation from the workspace root:
    #   bazel query 'kind(py_test.*, tests(python/...) intersect
    #   attr(tags, "\bteam:ml\b", python/...) except attr(tags, "\bno_main\b",
    #   python/...))' --output xml | xq | python ci/lint/pytest_checker.py
    main(sys.stdin.read())