forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheck_imports.py
131 lines (113 loc) Β· 4.57 KB
/
check_imports.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""This script checks documentation for broken import statements."""
import importlib
import json
import logging
import os
import re
import warnings
from pathlib import Path
from typing import List, Tuple
logger = logging.getLogger(__name__)
DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
import_pattern = re.compile(
r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL
)
def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
"""Get (module, import) statements from a single code cell."""
import_statements = []
for line in code_lines:
line = line.strip()
if line.startswith("#") or not line:
continue
# Join lines that end with a backslash
if line.endswith("\\"):
line = line[:-1].rstrip() + " "
continue
matches = import_pattern.findall(line)
for match in matches:
if match[0]: # simple import statement
import_statements.append((match[0], ""))
else: # from ___ import statement
module, items = match[1], match[2]
items_list = items.replace(" ", "").split(",")
for item in items_list:
import_statements.append((module, item))
return import_statements
def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
    """Get (module, import) statements from a Jupyter notebook.

    Args:
        notebook_path: Path to an ``.ipynb`` file, read as UTF-8 JSON.

    Returns:
        (module, item) pairs collected from every code cell, in notebook order.
    """
    with open(notebook_path, "r", encoding="utf-8") as file:
        notebook = json.load(file)
    import_statements = []
    for cell in notebook["cells"]:
        if cell["cell_type"] != "code":
            continue
        code_lines = cell["source"]
        if isinstance(code_lines, str):
            # nbformat allows ``source`` to be one string instead of a list of
            # lines; iterating a str directly would yield single characters.
            code_lines = code_lines.splitlines(keepends=True)
        import_statements.extend(_get_imports_from_code_cell(code_lines))
    return import_statements
def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""Collect offending import statements."""
offending_imports = []
for module, item in import_statements:
try:
if item:
try:
# submodule
full_module_name = f"{module}.{item}"
importlib.import_module(full_module_name)
except ModuleNotFoundError:
# attribute
try:
imported_module = importlib.import_module(module)
getattr(imported_module, item)
except AttributeError:
offending_imports.append((module, item))
except Exception:
offending_imports.append((module, item))
else:
importlib.import_module(module)
except Exception:
offending_imports.append((module, item))
return offending_imports
def _is_relevant_import(module: str) -> bool:
"""Check if module is recognized."""
# Ignore things like langchain_{bla}, where bla is unrecognized.
recognized_packages = [
"langchain",
"langchain_core",
"langchain_community",
"langchain_experimental",
"langchain_text_splitters",
]
return module.split(".")[0] in recognized_packages
def _serialize_bad_imports(bad_files: list) -> str:
"""Serialize bad imports to a string."""
bad_imports_str = ""
for file, bad_imports in bad_files:
bad_imports_str += f"File: {file}\n"
for module, item in bad_imports:
bad_imports_str += f" {module}.{item}\n"
return bad_imports_str
def check_notebooks(directory: str) -> list:
    """Check notebooks for broken import statements.

    Walks ``directory`` recursively in sorted order, so the report is the same
    on every filesystem. Every ``.ipynb`` file except Jupyter
    ``-checkpoint.ipynb`` copies is checked. Only imports whose module passes
    ``_is_relevant_import`` are checked.

    Args:
        directory: Root directory to scan (anything ``os.walk`` accepts).

    Returns:
        ``(notebook_path, bad_imports)`` pairs, one for each notebook that has
        at least one import that failed to resolve.
    """
    bad_files = []
    for root, dirs, files in os.walk(directory):
        # os.walk yields entries in arbitrary, filesystem-dependent order.
        # Sorting ``dirs`` in place also fixes the order subdirectories are visited.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".ipynb") or file.endswith("-checkpoint.ipynb"):
                continue
            notebook_path = os.path.join(root, file)
            import_statements = [
                (module, item)
                for module, item in _extract_import_statements(notebook_path)
                if _is_relevant_import(module)
            ]
            bad_imports = _get_bad_imports(import_statements)
            if bad_imports:
                bad_files.append((notebook_path, bad_imports))
    return bad_files
if __name__ == "__main__":
    # Scan the docs tree and fail with an ImportError that lists every broken import.
    offending_files = check_notebooks(DOCS_DIR)
    if offending_files:
        report = _serialize_bad_imports(offending_files)
        raise ImportError("Found bad imports:\n" + report)