-
-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathgenerate_tests_from_examples.py
executable file
·108 lines (92 loc) · 3.88 KB
/
generate_tests_from_examples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
import os
import re
import sys
from itertools import takewhile
from pathlib import Path
from shutil import rmtree
from typing import Iterable, Iterator, Tuple
# Repository layout: this script lives in <root>/scripts/.
ROOT_DIR = Path(__file__).parent.parent
EXAMPLES_PATH = ROOT_DIR / "examples"
GENERATED_PATH = ROOT_DIR / "tests" / "__generated__"

# scripts/test_wrapper.py is split around its first line starting with "##":
# everything before it is emitted at the top of each generated test module,
# everything after it at the bottom. Read as UTF-8 (the default encoding of
# Python source files) rather than the platform's locale encoding.
with open(
    ROOT_DIR / "scripts" / "test_wrapper.py", encoding="utf-8"
) as wrapper_file:
    # takewhile consumes (and drops) the first "##" line itself, so a bare
    # "##\n" marker is put back at the end of before_lines and at the start
    # of after_lines; the remaining file iterator yields the lines after it.
    before_lines = [
        *takewhile(lambda line: not line.startswith("##"), wrapper_file),
        "##\n",
    ]
    after_lines = ["##\n", *wrapper_file]
def iter_paths() -> Iterator[Tuple[Path, Path]]:
    """Yield ``(example_file, generated_test_file)`` pairs.

    Every ``*.py`` file found recursively under ``EXAMPLES_PATH`` (package
    ``__init__.py`` files excepted) is mapped to ``test_<name>`` at the same
    relative location under ``GENERATED_PATH``. The destination directory is
    created before the pair is yielded, so the caller can open the test file
    for writing straight away.
    """
    for source in EXAMPLES_PATH.glob("**/*.py"):
        if source.name != "__init__.py":
            rel = source.relative_to(EXAMPLES_PATH)
            target_dir = GENERATED_PATH / rel.parent
            target_dir.mkdir(parents=True, exist_ok=True)
            yield source, target_dir / ("test_" + rel.name)
# Prefix added to every line placed inside a generated test function.
INDENTATION = " " * 4

# Matches a PEP 604 union such as ": int | None" or "[List[int] | str":
#   ..                 two characters of context, kept verbatim by replace_union
#   (\w+(\[.+?\])? \| )+  one or more "Name" / "Name[...]" members + " | "
#   (\w+)              a final bare "Name" (no subscript allowed)
# NOTE: the pattern is not recursive, so nested unions such as
# Connection[Ship | None] | None are not handled correctly.
union_regex = re.compile(r"..(\w+(\[.+?\])? \| )+(\w+)")
# Match is used only as an annotation on replace_union. On interpreters whose
# `re` module does not expose it, fall back to Ellipsis so that annotation
# still evaluates.
try:
    Match = re.Match
except AttributeError:
    Match = ...  # type: ignore
def replace_union(match: Match) -> str:
    """Rewrite a matched ``A | B | ...`` union as ``Union[A, B, ...]``.

    The first two characters of the match are context captured by the ``..``
    prefix of ``union_regex`` (e.g. ``": "`` or ``"= "``) and are kept
    verbatim. When that context starts with ``"="`` the union is the right-hand
    side of an assignment (GraphQL types, per the original note). Such unions
    are returned unchanged unless their last member is ``None``.
    """
    text = match.group(0)
    prefix, union = text[:2], text[2:]
    members = [member.strip() for member in union.split("|")]
    if prefix.startswith("=") and members[-1] != "None":
        return text
    return prefix + "Union[" + ", ".join(members) + "]"
def handle_union(line: str) -> str:
    """Return *line* with every union matched by ``union_regex`` rewritten.

    Each match is passed to ``replace_union``, which decides whether to turn
    it into a ``Union[...]`` form.
    """
    return re.sub(union_regex, replace_union, line)
def _iter_test_chunks(example: Iterator[str], test_stem: str) -> Iterator[str]:
    """Yield the text of a generated test module built from *example* lines.

    The module-level wrapper prelude (``before_lines``) is emitted first, or
    right after a leading ``from __future__`` import, since such imports must
    stay at the top of the module. Example code is then copied as-is until a
    line starting with ``assert `` or ``with raises(``. That line opens a new
    test function ``<test_stem><n>``, and the following lines are indented
    into it until a line starting with ``class `` or ``@`` returns to module
    level. ``after_lines`` closes the module.
    """
    # Default to "" so an empty example file yields only the wrapper instead
    # of leaking a bare StopIteration out of main().
    first_line = next(example, "")
    # 3.9 compatibility is added after the __future__ import.
    # However, Annotated/Literal/etc. can still be an issue.
    if first_line.startswith("from __future__ import"):
        yield first_line
        yield from before_lines
    else:
        yield from before_lines
        yield first_line
    test_count = 0
    while True:
        # Classes must be declared in the global namespace for get_type_hints
        # and is_method to work, so module-level code is copied unindented.
        # A test function begins at the first assertion.
        for line in example:
            if line.startswith(("assert ", "with raises(")):
                yield f"def {test_stem}{test_count}():\n"
                yield INDENTATION + line
                break
            yield line
        else:
            break  # example exhausted while at module level
        cur_indent = INDENTATION
        for line in example:
            if line.startswith(("class ", "@")):
                yield line
                test_count += 1
                break
            yield cur_indent + line
            # Lines inside a multi-line triple-quoted string are written
            # unindented so the literal's content is not altered. Toggle only
            # when the line opens or closes such a string (odd count), so a
            # one-line """...""" does not flip the state.
            if line.count('"""') % 2:
                cur_indent = "" if cur_indent else INDENTATION
        else:
            break  # example exhausted inside a test function
    yield from after_lines


def main():
    """Regenerate ``GENERATED_PATH`` from the example files.

    Deletes and recreates the generated directory. Each example from
    ``iter_paths()`` has ``handle_union`` applied to every line and is
    converted by ``_iter_test_chunks`` into its test module. Finally, an
    empty ``__init__.py`` is written into every generated directory.
    """
    if GENERATED_PATH.exists():
        rmtree(GENERATED_PATH)
    GENERATED_PATH.mkdir(parents=True)
    for example_path, test_path in iter_paths():
        with open(example_path, encoding="utf-8") as example_file:
            with open(test_path, "w", encoding="utf-8") as test:
                # Union rewriting is always applied: the former guard on
                # sys.version_info / TOXENV ended in "or True".
                example = map(handle_union, example_file)
                test.writelines(_iter_test_chunks(example, test_path.stem))
    # "**" also matches GENERATED_PATH itself, so the root gets one too.
    for path in GENERATED_PATH.glob("**"):
        if path.is_dir():
            (path / "__init__.py").write_text("")


if __name__ == "__main__":
    main()