UnitTestGenerator.py (forked from qodo-ai/qodo-cover)
import os
import json
from cover_agent.Runner import Runner
from cover_agent.CoverageProcessor import CoverageProcessor
from cover_agent.CustomLogger import CustomLogger
from cover_agent.PromptBuilder import PromptBuilder
from cover_agent.AICaller import AICaller
from cover_agent.FilePreprocessor import FilePreprocessor


class UnitTestGenerator:
def __init__(
self,
prompt_template_path: str,
source_file_path: str,
test_file_path: str,
code_coverage_report_path: str,
test_command: str,
        test_command_dir: str = os.getcwd(),  # NOTE: default is evaluated once, at import time
        included_files: list = None,
        coverage_type: str = "cobertura",
        desired_coverage: int = 90,  # Default to 90% coverage if not specified
        additional_instructions: str = "",
):
"""
Initialize the object with the provided paths and settings.
Parameters:
prompt_template_path (str): The path to the prompt template file.
source_file_path (str): The path to the source file.
test_file_path (str): The path to the test file.
code_coverage_report_path (str): The path to the code coverage report.
test_command (str): The command to run the tests.
test_command_dir (str): The directory to run the test command in. Defaults to the current working directory.
included_files (list, optional): List of included file paths. Defaults to None.
coverage_type (str): Type of coverage report. Defaults to "cobertura".
desired_coverage (int): The desired coverage percentage.
Returns:
None
"""
# Class variables
self.prompt_template_path = prompt_template_path
self.source_file_path = source_file_path
self.test_file_path = test_file_path
self.code_coverage_report_path = code_coverage_report_path
self.test_command = test_command
self.test_command_dir = test_command_dir
self.included_files = self.get_included_files(included_files)
self.coverage_type = coverage_type
self.desired_coverage = desired_coverage
self.additional_instructions = additional_instructions
# Get the logger instance from CustomLogger
self.logger = CustomLogger.get_logger(__name__)
# States to maintain within this class
self.preprocessor = FilePreprocessor(self.test_file_path)
self.failed_test_runs = []
# Run coverage and build the prompt
self.run_coverage()
self.prompt = self.build_prompt()

    def run_coverage(self):
        """
        Run an initial build/test command to generate a coverage report and
        establish the baseline coverage.

        Parameters:
            - None

        Returns:
            - None
        """
# Perform an initial build/test command to generate coverage report and get a baseline
self.logger.info(
f'Running initial build/test command to generate coverage report: "{self.test_command}"'
)
stdout, stderr, exit_code, time_of_test_command = Runner.run_command(
command=self.test_command, cwd=self.test_command_dir
)
assert (
exit_code == 0
), f"Fatal: Error running test command. Failed with exit code {exit_code}. \nStdout: \n{stdout} \nStderr: \n{stderr}"
# Instantiate CoverageProcessor and process the coverage report
coverage_processor = CoverageProcessor(
file_path=self.code_coverage_report_path,
filename=os.path.basename(self.source_file_path),
coverage_type=self.coverage_type,
)
# Use the process_coverage_report method of CoverageProcessor, passing in the time the test command was executed
try:
lines_covered, lines_missed, percentage_covered = (
coverage_processor.process_coverage_report(
time_of_test_command=time_of_test_command
)
)
# Process the extracted coverage metrics
self.current_coverage = percentage_covered
self.code_coverage_report = f"Lines covered: {lines_covered}\nLines missed: {lines_missed}\nPercentage covered: {round(percentage_covered * 100, 2)}%"
        except AssertionError as error:
            # The coverage report does not exist or was not updated after the test command
            self.logger.error(f"Error in coverage processing: {error}")
            # Re-raise: without a baseline coverage figure the run cannot continue
            raise
except (ValueError, NotImplementedError) as e:
# Handle errors related to unsupported coverage report types or issues in parsing
self.logger.warning(f"Error parsing coverage report: {e}")
self.logger.info(
"Will default to using the full coverage report. You will need to check coverage manually for each passing test."
)
with open(self.code_coverage_report_path, "r") as f:
self.code_coverage_report = f.read()
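        # Note: with the default "cobertura" coverage_type, the report at
        # code_coverage_report_path is expected to be Cobertura-style XML,
        # e.g. as produced by pytest-cov with --cov-report=xml.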

    @staticmethod
def get_included_files(included_files):
"""
A method to read and concatenate the contents of included files into a single string.
Parameters:
included_files (list): A list of paths to included files.
Returns:
str: A string containing the concatenated contents of the included files, or an empty string if the input list is empty.
"""
if included_files:
included_files_content = []
for file_path in included_files:
try:
with open(file_path, "r") as file:
included_files_content.append(file.read())
except IOError as e:
print(f"Error reading file {file_path}: {str(e)}")
return "\n".join(included_files_content) if included_files_content else None
return ""

    def build_prompt(self):
        """
        Build the LLM prompt, embedding any previously failed test runs so the
        model can avoid repeating them.
        """
# Check for existence of failed tests:
if not self.failed_test_runs:
failed_test_runs_value = ""
else:
failed_test_runs_value = json.dumps(self.failed_test_runs).replace(
"\\n", "\n"
)
# Call PromptBuilder to build the prompt
prompt = PromptBuilder(
prompt_template_path=self.prompt_template_path,
source_file_path=self.source_file_path,
test_file_path=self.test_file_path,
code_coverage_report=self.code_coverage_report,
included_files=self.included_files,
additional_instructions=self.additional_instructions,
failed_test_runs=failed_test_runs_value,
)
return prompt.build_prompt()

    def generate_tests(self, LLM_model="gpt-4o", max_tokens=4096, dry_run=False):
        """
        Generate candidate tests by calling the LLM (or returning a canned
        response when dry_run is True) and splitting the reply on
        triple-backtick fences.

        Returns:
            list: The individual test snippets, stripped of trailing whitespace.
        """
# Call AICaller to generate the tests
ai_caller = AICaller(LLM_model)
self.prompt = self.build_prompt()
self.logger.info(
f"Token count for LLM model {LLM_model}: {ai_caller.count_tokens(self.prompt)}"
)
if dry_run:
# Provide a canned response. Used for testing.
response = "```def test_something():\n pass```\n```def test_something_else():\n pass```\n```def test_something_different():\n pass```"
        else:
            response = ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens)

        # The model is expected to separate individual tests with triple
        # backticks. Split on the fences and keep the non-empty pieces,
        # stripping only trailing whitespace so that leading indentation in
        # the returned code is preserved.
        tests = response.split("```")
        return [test.rstrip() for test in tests if test.rstrip()]
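        # For example, the dry-run response above splits into exactly three
        # candidate tests; the empty strings and bare newlines between
        # adjacent fences are falsy after rstrip() and are dropped. A
        # response using language-tagged fences (```python ... ```) would
        # keep the "python" tag as the first line of the snippet.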

    def validate_test(self, generated_test: str):
        """
        Validate a single generated test case by running it and checking coverage.

        This function appends the generated test to the test file, runs it, and checks the output.
        If the test fails or does not increase coverage, it rolls back changes and records the failure.

        Parameters:
            generated_test (str): The test code to validate.

        Returns:
            dict: A dictionary containing the test result status, reason for failure (if any),
                  stdout, stderr, exit code, and the test itself.
        """
# Step 0: Run the test through the preprocessor rule set
processed_test = self.preprocessor.process_file(generated_test)
# Step 1: Append the generated test to the test file and save the original content
with open(self.test_file_path, "r+") as test_file:
original_content = test_file.read() # Store original content
test_file.write(
"\n"
+ ("\n" if not original_content.endswith("\n") else "")
+ processed_test
+ "\n"
) # Append the new test at the end
# Step 2: Run the test using the Runner class
self.logger.info(
f'Running test with the following command: "{self.test_command}"'
)
stdout, stderr, exit_code, time_of_test_command = Runner.run_command(
command=self.test_command, cwd=self.test_command_dir
)
# Step 3: Check for pass/fail from the Runner object
if exit_code != 0:
# Test failed, roll back the test file to its original content
with open(self.test_file_path, "w") as test_file:
test_file.write(original_content)
self.logger.error(f"Test failed. Rolling back")
fail_details = {
"status": "FAIL",
"reason": "Test failed",
"exit_code": exit_code,
"stderr": stderr,
"stdout": stdout,
"test": generated_test,
}
            self.failed_test_runs.append(
                fail_details["test"]
            )  # Record the failing test so the next prompt can avoid repeating it
return fail_details
# If test passed, check for coverage increase
try:
# Step 4: Check that the coverage has increased using the CoverageProcessor class
new_coverage_processor = CoverageProcessor(
file_path=self.code_coverage_report_path,
filename=os.path.basename(self.source_file_path),
coverage_type=self.coverage_type,
)
_, _, new_percentage_covered = (
new_coverage_processor.process_coverage_report(
time_of_test_command=time_of_test_command
)
)
            if new_percentage_covered <= self.current_coverage:
                # Coverage has not increased; roll back by restoring the test
                # file to its original content
                with open(self.test_file_path, "w") as test_file:
                    test_file.write(original_content)
                self.logger.info("Test did not increase coverage. Rolling back.")
fail_details = {
"status": "FAIL",
"reason": "Coverage did not increase",
"exit_code": exit_code,
"stderr": stderr,
"stdout": stdout,
"test": generated_test,
}
                self.failed_test_runs.append(
                    fail_details["test"]
                )  # Record the failing test so the next prompt can avoid repeating it
return fail_details
        except Exception as e:
            # Handle unexpected errors during coverage processing
            self.logger.error(f"Error during coverage verification: {e}")
            # Roll back the test file, since the coverage gain could not be verified
            with open(self.test_file_path, "w") as test_file:
                test_file.write(original_content)
fail_details = {
"status": "FAIL",
"reason": "Runtime error",
"exit_code": exit_code,
"stderr": stderr,
"stdout": stdout,
"test": generated_test,
}
            self.failed_test_runs.append(
                fail_details["test"]
            )  # Record the failing test so the next prompt can avoid repeating it
return fail_details
# If everything passed and coverage increased, update current coverage and log success
self.current_coverage = new_percentage_covered
self.logger.info(
f"Test passed and coverage increased. Current coverage: {round(new_percentage_covered * 100, 2)}%"
)
return {
"status": "PASS",
"reason": "",
"exit_code": exit_code,
"stderr": stderr,
"stdout": stdout,
"test": generated_test,
}
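

# --- Illustrative usage sketch ---
# A minimal, hypothetical driver showing the intended flow: establish a
# baseline, generate candidate tests, then validate each one. Every path and
# the pytest command below are assumptions for illustration only; they are
# not taken from this project.
if __name__ == "__main__":
    generator = UnitTestGenerator(
        prompt_template_path="templates/test_generation_prompt.toml",  # hypothetical
        source_file_path="app/calculator.py",  # hypothetical
        test_file_path="tests/test_calculator.py",  # hypothetical
        code_coverage_report_path="coverage.xml",
        test_command="pytest --cov=app --cov-report=xml",
        desired_coverage=90,
    )
    for candidate in generator.generate_tests(max_tokens=4096):
        result = generator.validate_test(candidate)
        print(result["status"], result["reason"])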