Skip to content

Commit

Permalink
Added new unit tests for the 'ConfigManager' class
Browse files Browse the repository at this point in the history
- In tests, used tmp_path directly instead of converting it to a string.
- Used the built-in 'exclude_none' option in the 'dict' method.
  • Loading branch information
basicthinker committed Sep 5, 2023
1 parent db2c472 commit 61c0c91
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 17 deletions.
7 changes: 3 additions & 4 deletions devchat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _create_sample_config(self):
)
])
with open(self.config_path, 'w', encoding='utf-8') as file:
yaml.dump(sample_config.dict(), file)
yaml.dump(sample_config.dict(exclude_none=True), file)

def _load_and_validate_config(self) -> ChatConfig:
with open(self.config_path, 'r', encoding='utf-8') as file:
Expand All @@ -65,8 +65,7 @@ def update_model_config(self, model_config: ModelConfig) -> ModelConfig:
if model_config.max_input_tokens is not None:
model.max_input_tokens = model_config.max_input_tokens
if model_config.parameters is not None:
updated_parameters = model.parameters.dict()
updated_parameters.update(
{k: v for k, v in model_config.parameters.dict().items() if v is not None})
updated_parameters = model.parameters.dict(exclude_none=True)
updated_parameters.update(model_config.parameters.dict(exclude_none=True))
model.parameters = OpenAIChatParameters(**updated_parameters)
return model
10 changes: 2 additions & 8 deletions devchat/openai/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ def load_prompt(self, data: dict) -> OpenAIPrompt:

def complete_response(self, prompt: OpenAIPrompt) -> str:
# Filter the config parameters with non-None values
config_params = {
key: value
for key, value in self.config.dict().items() if value is not None
}
config_params = self.config.dict(exclude_none=True).items()
if prompt.get_functions():
config_params['functions'] = prompt.get_functions()
config_params['function_call'] = 'auto'
Expand All @@ -83,10 +80,7 @@ def complete_response(self, prompt: OpenAIPrompt) -> str:

def stream_response(self, prompt: OpenAIPrompt) -> Iterator:
# Filter the config parameters with non-None values
config_params = {
key: value
for key, value in self.config.dict().items() if value is not None
}
config_params = self.config.dict(exclude_none=True)
if prompt.get_functions():
config_params['functions'] = prompt.get_functions()
config_params['function_call'] = 'auto'
Expand Down
2 changes: 1 addition & 1 deletion tests/test_command_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_parse_command():

def test_command_parser(tmp_path):
# Create a Namespace instance with the temporary directory as the root path
namespace = Namespace(str(tmp_path))
namespace = Namespace(tmp_path)
command_parser = CommandParser(namespace)

# Test with a valid configuration file with most fields filled
Expand Down
34 changes: 34 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from devchat.config import ConfigManager, ModelConfig, ChatConfig, OpenAIChatParameters


def test_create_sample_config(tmp_path):
ConfigManager(tmp_path)
assert os.path.exists(os.path.join(tmp_path, 'config.yml'))


def test_load_and_validate_config(tmp_path):
config_manager = ConfigManager(tmp_path)
assert isinstance(config_manager.config, ChatConfig)


def test_get_model_config(tmp_path):
config_manager = ConfigManager(tmp_path)
model_config = config_manager.get_model_config('gpt-4')
assert model_config.id == 'gpt-4'
assert model_config.max_input_tokens == 6000
assert model_config.parameters.temperature == 0
assert model_config.parameters.stream is True


def test_update_model_config(tmp_path):
config_manager = ConfigManager(tmp_path)
new_model_config = ModelConfig(
id='gpt-4',
max_input_tokens=7000,
parameters=OpenAIChatParameters(temperature=0.5)
)
updated_model_config = config_manager.update_model_config(new_model_config)
assert updated_model_config.max_input_tokens == 7000
assert updated_model_config.parameters.temperature == 0.5
assert updated_model_config.parameters.stream is True
6 changes: 3 additions & 3 deletions tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_is_valid_name():

def test_get_file(tmp_path):
# Create a Namespace instance with the temporary directory as the root path
namespace = Namespace(str(tmp_path))
namespace = Namespace(tmp_path)

# Test case 1: a file that exists
# Create a file in the 'usr' branch
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_get_file(tmp_path):

def test_list_files(tmp_path):
# Create a Namespace instance with the temporary directory as the root path
namespace = Namespace(str(tmp_path))
namespace = Namespace(tmp_path)

# Test case 1: a path that exists
# Create a file in the 'usr' branch
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_list_names(tmp_path):
os.makedirs(os.path.join(tmp_path, 'org', 'a', 'b', 'd'))
os.makedirs(os.path.join(tmp_path, 'sys', 'a', 'e'))

namespace = Namespace(str(tmp_path))
namespace = Namespace(tmp_path)

# Test listing child commands
commands = namespace.list_names('a')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def test_prompter(tmp_path):
namespace = Namespace(str(tmp_path))
namespace = Namespace(tmp_path)
prompter = RecursivePrompter(namespace)

# Test when there are no 'prompt.txt' files
Expand Down

0 comments on commit 61c0c91

Please sign in to comment.