forked from apache/airflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
import_all_classes.py
executable file
·141 lines (127 loc) · 5.63 KB
/
import_all_classes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import importlib
import pkgutil
import sys
import traceback
import warnings
from inspect import isclass
from typing import List, Optional, Set, Tuple
from warnings import WarningMessage
from rich.console import Console
# Module-level rich consoles: ``console`` prints to stdout and ``error_console``
# prints to stderr. Both use a fixed, very wide width (presumably so long module
# paths and traceback lines are not wrapped) and the basic "standard" color set.
console = Console(color_system="standard", width=400)
error_console = Console(color_system="standard", width=400, stderr=True)
def _package_from_traceback(exception_string: str, provider_prefixes: List[str]) -> Optional[str]:
    """
    Extract the failing provider package name from a formatted traceback.

    The first prefix (in ``provider_prefixes`` order) that occurs anywhere in the
    traceback wins. The returned text starts at that occurrence and runs to the end
    of its line.

    :param exception_string: traceback text, as produced by ``traceback.format_exc()``
    :param provider_prefixes: module-name prefixes of the providers being checked
    :return: the extracted package text, or None when no prefix occurs in the traceback
    """
    for provider_prefix in provider_prefixes:
        start_index = exception_string.find(provider_prefix)
        if start_index == -1:
            continue
        end_index = exception_string.find("\n", start_index + len(provider_prefix))
        if end_index == -1:
            # No newline after the match: take the rest of the string rather than
            # slicing to -1, which would silently drop the final character.
            end_index = len(exception_string)
        return exception_string[start_index:end_index]
    return None


def import_all_classes(
    paths: List[str],
    prefix: str,
    provider_ids: Optional[List[str]] = None,
    print_imports: bool = False,
    print_skips: bool = False,
) -> Tuple[List[str], List[WarningMessage]]:
    """
    Import all classes in provider packages.

    Loads and imports every module found in the providers, so that all subclasses
    of operators/sensors etc. can be found afterwards. Modules whose names do not
    start with one of the selected provider prefixes are skipped. Note that
    ``pkgutil.walk_packages`` itself still imports packages in order to descend
    into them.

    :param paths: list of paths to look the provider packages in
    :param prefix: module-name prefix prepended to every module found under ``paths``
    :param provider_ids: provider ids that should be loaded; when None or empty,
        every module under ``prefix`` is loaded
    :param print_imports: if True, print each package as it is imported
    :param print_skips: if True, print each module skipped because it is outside
        the selected providers
    :return: tuple of the fully-qualified names of all imported classes and all
        warnings recorded while importing
    :raises SystemExit: with code 1, after printing every collected traceback, if
        any module failed to import
    """
    imported_classes: List[str] = []
    tracebacks: List[Tuple[str, str]] = []
    printed_packages: Set[str] = set()
    all_warnings: List[WarningMessage] = []

    if provider_ids:
        provider_prefixes = [f'{prefix}{provider_id}' for provider_id in provider_ids]
    else:
        provider_prefixes = [prefix]
    # str.startswith accepts a tuple, so a single call checks every prefix.
    prefix_tuple = tuple(provider_prefixes)

    def onerror(_):
        # pkgutil calls this (with the package name, unused here) when importing a
        # package during the walk raises. The traceback is read from the exception
        # being handled, and only failures mentioning a selected provider are kept.
        exception_string = traceback.format_exc()
        package = _package_from_traceback(exception_string, provider_prefixes)
        if package is not None:
            tracebacks.append((package, exception_string))

    for modinfo in pkgutil.walk_packages(path=paths, prefix=prefix, onerror=onerror):
        if not modinfo.name.startswith(prefix_tuple):
            if print_skips:
                console.print(f"Skipping module: {modinfo.name}")
            continue
        if print_imports:
            package_to_print = modinfo.name.rpartition(".")[0]
            if package_to_print not in printed_packages:
                printed_packages.add(package_to_print)
                console.print(f"Importing package: {package_to_print}")
        try:
            with warnings.catch_warnings(record=True) as caught:
                # Record every DeprecationWarning, even ones already emitted elsewhere.
                warnings.filterwarnings("always", category=DeprecationWarning)
                _module = importlib.import_module(modinfo.name)
                for attribute_name in dir(_module):
                    attribute = getattr(_module, attribute_name)
                    if isclass(attribute):
                        imported_classes.append(f"{modinfo.name}.{attribute_name}")
            all_warnings.extend(caught)
        except Exception:
            # Collect rather than abort, so every broken module is reported at once.
            tracebacks.append((modinfo.name, traceback.format_exc()))

    if tracebacks:
        console.print(
            """
[red]ERROR: There were some import errors[/]
""",
        )
        error_console.print("[red]----------------------------------------[/]")
        for package, trace in tracebacks:
            error_console.print(f"Exception when importing: {package}\n\n")
            error_console.print(trace)
            error_console.print("[red]----------------------------------------[/]")
        sys.exit(1)
    return imported_classes, all_warnings
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Perform import of all provider classes.')
    # Required: without any --path, pkgutil.walk_packages(path=None) would walk every
    # top-level module on sys.path and try to import each one under the prefix.
    parser.add_argument('--path', action='append', required=True, help='paths to search providers in')
    parser.add_argument('--prefix', help='prefix to add in front of the class', default='airflow.providers.')
    args = parser.parse_args()

    # Output goes through the module-level rich consoles. The builtin print() would
    # emit the "[red]...[/]" style markup tags literally instead of rendering them.
    console.print()
    console.print(f"Walking all packages in {args.path} with prefix {args.prefix}", markup=False)
    console.print()
    classes, warns = import_all_classes(
        print_imports=True, print_skips=True, paths=args.path, prefix=args.prefix
    )
    if not classes:
        error_console.print("[red]Something is seriously wrong - no classes imported[/]")
        sys.exit(1)
    if warns:
        console.print("[yellow]There were warnings generated during the import[/]")
        for w in warns:
            one_line_message = str(w.message).replace('\n', ' ')
            # markup=False: warning text may contain square brackets that rich would
            # otherwise try to parse as markup tags, so color via style= instead.
            console.print(f"{w.filename}:{w.lineno}: {one_line_message}", style="yellow", markup=False)
    console.print()
    console.print(f"[green]SUCCESS: All provider packages are importable! Imported {len(classes)} classes.[/]")
    console.print()