diff --git a/devchat/config.py b/devchat/config.py
index 0cb7dd13..c2c34d59 100644
--- a/devchat/config.py
+++ b/devchat/config.py
@@ -8,7 +8,7 @@ from devchat.anthropic import AnthropicChatParameters
 
 
-class ClientType(str, Enum):
+class Client(str, Enum):
     OPENAI = "openai"
     ANTHROPIC = "anthropic"
 
@@ -19,41 +19,36 @@ class OpenAIClientConfig(BaseModel, extra='forbid'):
     api_type: Optional[str]
     api_version: Optional[str]
     deployment_name: Optional[str]
-    # client-level parameters
-    parameters: Optional[OpenAIChatParameters] = OpenAIChatParameters()
 
 
 class AnthropicClientConfig(BaseModel, extra='forbid'):
     api_key: str
     api_base: Optional[str]
     timeout: Optional[float]
-    # client-level parameters
-    parameters: Optional[AnthropicChatParameters] = AnthropicChatParameters()
 
 
 class ProviderConfig(BaseModel, extra='forbid', use_enum_values=True):
     id: str
-    client_type: ClientType
+    client: Client
     client_config: Union[OpenAIClientConfig, AnthropicClientConfig]
 
     @validator('client_config', pre=True)
     def create_client_config(cls, config, values):  # pylint: disable=E0213
         if not isinstance(config, dict):
             return config
-        if 'client_type' in values:
-            if values['client_type'] == ClientType.OPENAI:
+        if 'client' in values:
+            if values['client'] == Client.OPENAI:
                 return OpenAIClientConfig(**config)
-            if values['client_type'] == ClientType.ANTHROPIC:
+            if values['client'] == Client.ANTHROPIC:
                 return AnthropicClientConfig(**config)
-        raise ValueError(f"Invalid client_type in {values}")
+        raise ValueError(f"Invalid client in {values}")
 
 
 class ModelConfig(BaseModel, extra='forbid'):
     id: str
     max_input_tokens: Optional[int] = sys.maxsize
-    # model-level parameters that override client-level parameters
     parameters: Optional[Union[OpenAIChatParameters, AnthropicChatParameters]]
-    provider_id: Optional[str]
+    provider: Optional[str]
 
 
 class ChatConfig(BaseModel, extra='forbid'):
@@ -75,17 +70,20 @@ def _load_and_validate_config(self) -> ChatConfig:
         for index, provider in enumerate(config_data['providers']):
             config_data['providers'][index] = ProviderConfig(**provider)
         for model in config_data['models']:
-            if 'provider_id' not in model:
-                raise ValueError(f"Model in {self.config_path} is missing provider_id")
+            if 'provider' not in model:
+                raise ValueError(f"Model in {self.config_path} is missing provider")
             if 'parameters' in model:
-                providers = [p for p in config_data['providers'] if p.id == model['provider_id']]
+                providers = [p for p in config_data['providers'] if p.id == model['provider']]
                 if len(providers) < 1:
-                    raise ValueError(f"Model in {self.config_path} has invalid provider_id: "
-                                     f"{model['provider_id']}")
-                client = providers[0].client_config
-                parameters_data = client.parameters.dict(exclude_unset=True)
-                parameters_data.update(model['parameters'])
-                model['parameters'] = client.parameters.parse_obj(parameters_data)
+                    raise ValueError(f"Model in {self.config_path} has invalid provider: "
+                                     f"{model['provider']}")
+                if providers[0].client == Client.OPENAI:
+                    model['parameters'] = OpenAIChatParameters(**model['parameters'])
+                elif providers[0].client == Client.ANTHROPIC:
+                    model['parameters'] = AnthropicChatParameters(**model['parameters'])
+                else:
+                    raise ValueError(f"Model in {self.config_path} has invalid client: "
+                                     f"{providers[0].client}")
         return ChatConfig(**config_data)
 
     def model_config(self, model_id: Optional[str] = None) -> ModelConfig:
@@ -127,17 +125,17 @@ def _create_sample_config(self):
             providers=[
                 ProviderConfig(
                     id="devchat",
-                    client_type=ClientType.OPENAI,
+                    client=Client.OPENAI,
                     client_config=OpenAIClientConfig(api_key="DC....SET.THIS")
                 ),
                 ProviderConfig(
                     id="openai",
-                    client_type=ClientType.OPENAI,
+                    client=Client.OPENAI,
                     client_config=OpenAIClientConfig(api_key="sk-...SET-THIS")
                 ),
                 ProviderConfig(
                     id="azure-openai",
-                    client_type=ClientType.OPENAI,
+                    client=Client.OPENAI,
                     client_config=OpenAIClientConfig(
                         api_key="YOUR_AZURE_OPENAI_KEY",
                         api_base="YOUR_AZURE_OPENAI_ENDPOINT",
@@ -147,7 +145,7 @@ def _create_sample_config(self):
                 ),
                 ProviderConfig(
                     id="anthropic",
-                    client_type=ClientType.ANTHROPIC,
+                    client=Client.ANTHROPIC,
                     client_config=AnthropicClientConfig(api_key="sk-ant-...SET-THIS", timeout=30)
                 )
             ],
@@ -156,24 +154,24 @@ def _create_sample_config(self):
                     id="gpt-4",
                     max_input_tokens=6000,
                     parameters=OpenAIChatParameters(temperature=0, stream=True),
-                    provider_id='devchat'
+                    provider='devchat'
                 ),
                 ModelConfig(
                     id="gpt-3.5-turbo-16k",
                     max_input_tokens=12000,
                     parameters=OpenAIChatParameters(temperature=0, stream=True),
-                    provider_id='devchat'
+                    provider='devchat'
                 ),
                 ModelConfig(
                     id="gpt-3.5-turbo",
                     max_input_tokens=3000,
                     parameters=OpenAIChatParameters(temperature=0, stream=True),
-                    provider_id='devchat'
+                    provider='devchat'
                 ),
                 ModelConfig(
                     id="claude-2",
                     parameters=AnthropicChatParameters(max_tokens_to_sample=20000),
-                    provider_id='anthropic'
+                    provider='anthropic'
                 )
             ]
         )
diff --git a/tests/test_cli_prompt.py b/tests/test_cli_prompt.py
index c831c826..45e033a9 100644
--- a/tests/test_cli_prompt.py
+++ b/tests/test_cli_prompt.py
@@ -28,7 +28,7 @@ def test_prompt_with_temp_config_file(mock_home_dir):
     config_data = f"""
     providers:
       - id: openai
-        client_type: openai
+        client: openai
         client_config:
             api_key: {os.environ['OPENAI_API_KEY']}
     models:
@@ -36,7 +36,7 @@ def test_prompt_with_temp_config_file(mock_home_dir):
         max_input_tokens: 3000
         parameters:
             temperature: 0
-        provider_id: openai
+        provider: openai
     """
 
     chat_dir = os.path.join(mock_home_dir, ".chat")
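For reference, a minimal user config written against the renamed fields could look like the sketch below. It mirrors the sample config and the test fixture in this diff; the exact file location and the placeholder API key are illustrative assumptions, not part of the change itself.

providers:
  - id: openai
    client: openai              # was client_type
    client_config:
      api_key: sk-...SET-THIS   # placeholder value
models:
  - id: gpt-3.5-turbo
    max_input_tokens: 3000
    parameters:
      temperature: 0
    provider: openai            # was provider_id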