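"""data_pipeline.py

Utilities for building a code-RAG datastore with AdalFlow: clone a GitHub
repository, read its source and documentation files into Documents, split
and embed them, and persist the result in a LocalDB.
"""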
import os
import re
import glob
import subprocess

import adalflow as adal
from adalflow.core.types import Document, List
from adalflow.core.db import LocalDB
from adalflow.components.data_process import TextSplitter, ToEmbeddings
from adalflow.utils import get_adalflow_default_root_path

from config import configs


def extract_class_definition(content: str, class_name: str) -> str:
    """Extract a complete class definition from the content.

    Falls back to returning the full content if the class is not found.
    """
    lines = content.split('\n')
    class_start = -1
    indent_level = 0
    # Find the start of the class definition (word-boundary match, so that
    # e.g. "Foo" does not match "class FooBar")
    for i, line in enumerate(lines):
        if re.match(rf'\s*class\s+{re.escape(class_name)}\b', line):
            class_start = i
            # Record the indentation level of the class
            indent_level = len(line) - len(line.lstrip())
            break
if class_start == -1:
return content
# Collect the entire class definition
class_lines = [lines[class_start]]
current_line = class_start + 1
while current_line < len(lines):
line = lines[current_line]
# If we hit a line with same or less indentation, we're out of the class
if line.strip() and len(line) - len(line.lstrip()) <= indent_level:
break
class_lines.append(line)
current_line += 1
return '\n'.join(class_lines)
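

# Example usage of extract_class_definition (illustrative; the file name and
# class name below are hypothetical):
#
#     source = open("retriever.py").read()
#     snippet = extract_class_definition(source, "MyRetriever")
#     # `snippet` holds the class header plus every line indented under it
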
def extract_class_name_from_query(query: str) -> str:
    """Extract a likely class name from a query about a class.

    Returns None if no candidate class name is found.
    """
    # Common patterns for asking about classes
    patterns = [
        r'class (\w+)',
        r'the (\w+) class',
        r'what does (\w+) do',
        r'how does (\w+) work',
        r'show me (\w+)',
        r'explain (\w+)',
    ]
    # First, try to find a class name using the patterns
    for pattern in patterns:
        matches = re.findall(pattern, query, re.IGNORECASE)
        if matches:
            name = matches[0]
            # Preserve CamelCase names; only capitalize an all-lowercase match
            return name if any(c.isupper() for c in name) else name.capitalize()
    # Otherwise fall back to the first capitalized word that is not a common
    # English word. Note: the query must NOT be lowercased beforehand, or the
    # capitalization signal is lost.
    skip_words = {'the', 'class', 'show', 'me', 'how', 'does', 'what', 'is', 'are', 'explain'}
    for word in query.split():
        if word.lower() in skip_words:
            continue
        if word and word[0].isupper():
            return word
    return None
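

# Expected behaviour on a few sample queries (illustrative, not verified output):
#
#     extract_class_name_from_query("how does TextSplitter work")  # -> "TextSplitter"
#     extract_class_name_from_query("explain toembeddings")        # -> "Toembeddings"
#     extract_class_name_from_query("what is for lunch")           # -> None
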
def download_github_repo(repo_url: str, local_path: str):
"""
Downloads a GitHub repository to a specified local path.
Args:
repo_url (str): The URL of the GitHub repository to clone.
local_path (str): The local directory where the repository will be cloned.
Returns:
str: The output message from the `git` command.
"""
    try:
        print(f"local_path: {local_path}")
        # Check that Git is installed before attempting the clone
subprocess.run(
["git", "--version"],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# Ensure the local path exists
os.makedirs(local_path, exist_ok=True)
# Clone the repository
result = subprocess.run(
["git", "clone", repo_url, local_path],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
return result.stdout.decode("utf-8")
except subprocess.CalledProcessError as e:
return f"Error during cloning: {e.stderr.decode('utf-8')}"
except Exception as e:
return f"An unexpected error occurred: {str(e)}"
def read_all_documents(path: str):
"""
Recursively reads all documents in a directory and its subdirectories.
Args:
path (str): The root directory path.
Returns:
list: A list of Document objects with metadata.
"""
documents = []
# File extensions to look for, prioritizing code files
code_extensions = ['.py', '.js', '.ts', '.java', '.cpp', '.c', '.go', '.rs']
doc_extensions = ['.md', '.txt', '.rst', '.json', '.yaml', '.yml']
# Process code files first
for ext in code_extensions:
files = glob.glob(f"{path}/**/*{ext}", recursive=True)
for file_path in files:
if '.venv' in file_path or 'node_modules' in file_path:
continue
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
relative_path = os.path.relpath(file_path, path)
# Determine if this is an implementation file
is_implementation = (
not relative_path.startswith('test_') and
not relative_path.startswith('app_') and
'test' not in relative_path.lower()
)
doc = Document(
text=content,
meta_data={
"file_path": relative_path,
"type": ext[1:],
"is_code": True,
"is_implementation": is_implementation,
"title": relative_path
}
)
documents.append(doc)
except Exception as e:
print(f"Error reading {file_path}: {e}")
# Then process documentation files
for ext in doc_extensions:
files = glob.glob(f"{path}/**/*{ext}", recursive=True)
for file_path in files:
if '.venv' in file_path or 'node_modules' in file_path:
continue
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
relative_path = os.path.relpath(file_path, path)
doc = Document(
text=content,
meta_data={
"file_path": relative_path,
"type": ext[1:],
"is_code": False,
"is_implementation": False,
"title": relative_path
}
)
documents.append(doc)
except Exception as e:
print(f"Error reading {file_path}: {e}")
return documents
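

# Sketch of how the attached metadata can be used downstream (the path is
# hypothetical):
#
#     docs = read_all_documents("/tmp/repo")
#     impl_docs = [d for d in docs if d.meta_data["is_implementation"]]
#     print(f"{len(impl_docs)} implementation files out of {len(docs)} documents")
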
def prepare_data_pipeline():
"""Creates and returns the data transformation pipeline."""
splitter = TextSplitter(**configs["text_splitter"])
embedder = adal.Embedder(
model_client=configs["embedder"]["model_client"](),
model_kwargs=configs["embedder"]["model_kwargs"],
)
embedder_transformer = ToEmbeddings(
embedder=embedder, batch_size=configs["embedder"]["batch_size"]
)
data_transformer = adal.Sequential(
splitter, embedder_transformer
) # sequential will chain together splitter and embedder
return data_transformer
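

# The returned adal.Sequential can also be applied directly to a list of
# Documents, which is what LocalDB.transform does under the hood. A minimal
# sketch, assuming the configured embedder credentials are available:
#
#     pipeline = prepare_data_pipeline()
#     transformed = pipeline(documents)  # split into chunks, then embedded
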
def transform_documents_and_save_to_db(documents: List[Document], db_path: str):
"""
Transforms a list of documents and saves them to a local database.
Args:
documents (list): A list of `Document` objects.
db_path (str): The path to the local database file.
"""
# Get the data transformer
data_transformer = prepare_data_pipeline()
# Save the documents to a local database
db = LocalDB("code_db")
db.register_transformer(transformer=data_transformer, key="split_and_embed")
db.load(documents)
db.transform(key="split_and_embed")
os.makedirs(os.path.dirname(db_path), exist_ok=True)
db.save_state(filepath=db_path)
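

# To reuse the saved state later, reload it and fetch the transformed split.
# A sketch, assuming LocalDB's load_state/get_transformed_data API as used in
# AdalFlow's RAG examples:
#
#     db = LocalDB.load_state(filepath=db_path)
#     chunks = db.get_transformed_data(key="split_and_embed")
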
def create_sample_documents():
"""Create some sample documents for testing."""
sample_texts = [
"""Alice is a software engineer who loves coding in Python.
She specializes in machine learning and has worked on several NLP projects.
Her favorite project was building a chatbot for customer service.""",
"""Bob is a data scientist with expertise in deep learning.
He has published papers on transformer architectures and attention mechanisms.
Recently, he's been working on improving RAG systems.""",
"""The company cafeteria serves amazing tacos on Tuesdays.
They also have a great coffee machine that makes perfect lattes.
Many employees enjoy their lunch breaks in the outdoor seating area."""
]
return [Document(text=text, meta_data={"title": f"doc_{i}"})
for i, text in enumerate(sample_texts)]
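

# The sample set mixes two related profiles with one unrelated document, so
# retrieval quality is easy to eyeball:
#
#     docs = create_sample_documents()
#     print(docs[0].meta_data["title"])  # -> "doc_0"
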
if __name__ == "__main__":
    adal.setup_env()
# Example 1: Process a GitHub repository
print("Example 1: Processing a GitHub repository")
repo_url = "https://github.com/microsoft/LMOps"
local_path = os.path.join(get_adalflow_default_root_path(), "LMOps")
# Download the repository
print("\nDownloading repository...")
result = download_github_repo(repo_url, local_path)
print(result)
# Read documents from a specific directory
print("\nReading documents...")
target_path = os.path.join(local_path, "prompt_optimization")
documents = read_all_documents(target_path)
print(f"Found {len(documents)} documents")
# Transform and save to database
print("\nTransforming documents and saving to database...")
    db_path = os.path.join(get_adalflow_default_root_path(), "db_microsoft_lmops")
transform_documents_and_save_to_db(documents, db_path)
print(f"Database saved to {db_path}")
# Example 2: Process test documents
print("\nExample 2: Processing test documents")
documents = create_sample_documents()
db_path = os.path.join(get_adalflow_default_root_path(), "test_db")
transform_documents_and_save_to_db(documents, db_path)
print(f"Test database saved to {db_path}")