Remove parameters from client in config.py
basicthinker committed Sep 11, 2023
1 parent d8a145b commit 03703a2
Showing 2 changed files with 29 additions and 31 deletions.
56 changes: 27 additions & 29 deletions devchat/config.py
@@ -8,7 +8,7 @@
 from devchat.anthropic import AnthropicChatParameters
 
 
-class ClientType(str, Enum):
+class Client(str, Enum):
     OPENAI = "openai"
     ANTHROPIC = "anthropic"
 
@@ -19,41 +19,36 @@ class OpenAIClientConfig(BaseModel, extra='forbid'):
     api_type: Optional[str]
     api_version: Optional[str]
     deployment_name: Optional[str]
-    # client-level parameters
-    parameters: Optional[OpenAIChatParameters] = OpenAIChatParameters()
 
 
 class AnthropicClientConfig(BaseModel, extra='forbid'):
     api_key: str
     api_base: Optional[str]
     timeout: Optional[float]
-    # client-level parameters
-    parameters: Optional[AnthropicChatParameters] = AnthropicChatParameters()
 
 
 class ProviderConfig(BaseModel, extra='forbid', use_enum_values=True):
     id: str
-    client_type: ClientType
+    client: Client
     client_config: Union[OpenAIClientConfig, AnthropicClientConfig]
 
     @validator('client_config', pre=True)
     def create_client_config(cls, config, values):  # pylint: disable=E0213
         if not isinstance(config, dict):
             return config
-        if 'client_type' in values:
-            if values['client_type'] == ClientType.OPENAI:
+        if 'client' in values:
+            if values['client'] == Client.OPENAI:
                 return OpenAIClientConfig(**config)
-            if values['client_type'] == ClientType.ANTHROPIC:
+            if values['client'] == Client.ANTHROPIC:
                 return AnthropicClientConfig(**config)
-            raise ValueError(f"Invalid client_type in {values}")
+            raise ValueError(f"Invalid client in {values}")
 
 
 class ModelConfig(BaseModel, extra='forbid'):
     id: str
     max_input_tokens: Optional[int] = sys.maxsize
-    # model-level parameters that override client-level parameters
     parameters: Optional[Union[OpenAIChatParameters, AnthropicChatParameters]]
-    provider_id: Optional[str]
+    provider: Optional[str]
 
 
 class ChatConfig(BaseModel, extra='forbid'):
@@ -75,17 +70,20 @@ def _load_and_validate_config(self) -> ChatConfig:
         for index, provider in enumerate(config_data['providers']):
             config_data['providers'][index] = ProviderConfig(**provider)
         for model in config_data['models']:
-            if 'provider_id' not in model:
-                raise ValueError(f"Model in {self.config_path} is missing provider_id")
+            if 'provider' not in model:
+                raise ValueError(f"Model in {self.config_path} is missing provider")
             if 'parameters' in model:
-                providers = [p for p in config_data['providers'] if p.id == model['provider_id']]
+                providers = [p for p in config_data['providers'] if p.id == model['provider']]
                 if len(providers) < 1:
-                    raise ValueError(f"Model in {self.config_path} has invalid provider_id: "
-                                     f"{model['provider_id']}")
-                client = providers[0].client_config
-                parameters_data = client.parameters.dict(exclude_unset=True)
-                parameters_data.update(model['parameters'])
-                model['parameters'] = client.parameters.parse_obj(parameters_data)
+                    raise ValueError(f"Model in {self.config_path} has invalid provider: "
+                                     f"{model['provider']}")
+                if providers[0].client == Client.OPENAI:
+                    model['parameters'] = OpenAIChatParameters(**model['parameters'])
+                elif providers[0].client == Client.ANTHROPIC:
+                    model['parameters'] = AnthropicChatParameters(**model['parameters'])
+                else:
+                    raise ValueError(f"Model in {self.config_path} has invalid client: "
+                                     f"{providers[0].client}")
         return ChatConfig(**config_data)
 
     def model_config(self, model_id: Optional[str] = None) -> ModelConfig:
@@ -127,17 +125,17 @@ def _create_sample_config(self):
             providers=[
                 ProviderConfig(
                     id="devchat",
-                    client_type=ClientType.OPENAI,
+                    client=Client.OPENAI,
                     client_config=OpenAIClientConfig(api_key="DC....SET.THIS")
                 ),
                 ProviderConfig(
                     id="openai",
-                    client_type=ClientType.OPENAI,
+                    client=Client.OPENAI,
                     client_config=OpenAIClientConfig(api_key="sk-...SET-THIS")
                 ),
                 ProviderConfig(
                     id="azure-openai",
-                    client_type=ClientType.OPENAI,
+                    client=Client.OPENAI,
                     client_config=OpenAIClientConfig(
                         api_key="YOUR_AZURE_OPENAI_KEY",
                         api_base="YOUR_AZURE_OPENAI_ENDPOINT",
@@ -147,7 +145,7 @@ def _create_sample_config(self):
                 ),
                 ProviderConfig(
                     id="anthropic",
-                    client_type=ClientType.ANTHROPIC,
+                    client=Client.ANTHROPIC,
                     client_config=AnthropicClientConfig(api_key="sk-ant-...SET-THIS", timeout=30)
                 )
             ],
@@ -156,24 +154,24 @@
                     id="gpt-4",
                     max_input_tokens=6000,
                     parameters=OpenAIChatParameters(temperature=0, stream=True),
-                    provider_id='devchat'
+                    provider='devchat'
                 ),
                 ModelConfig(
                     id="gpt-3.5-turbo-16k",
                     max_input_tokens=12000,
                     parameters=OpenAIChatParameters(temperature=0, stream=True),
-                    provider_id='devchat'
+                    provider='devchat'
                 ),
                 ModelConfig(
                     id="gpt-3.5-turbo",
                     max_input_tokens=3000,
                     parameters=OpenAIChatParameters(temperature=0, stream=True),
-                    provider_id='devchat'
+                    provider='devchat'
                 ),
                 ModelConfig(
                     id="claude-2",
                     parameters=AnthropicChatParameters(max_tokens_to_sample=20000),
-                    provider_id='anthropic'
+                    provider='anthropic'
                 )
             ]
         )
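For orientation, here is a minimal sketch of how the renamed fields fit together after this commit. Only the field names (client, provider) and the sample values reused from the diff are taken from the repository; the devchat.openai import path is an assumption made by analogy with the devchat.anthropic import shown above.

# Illustrative sketch only: field names follow the new schema in this diff.
from devchat.config import Client, ModelConfig, OpenAIClientConfig, ProviderConfig
from devchat.openai import OpenAIChatParameters  # import path assumed

provider = ProviderConfig(
    id="openai",
    client=Client.OPENAI,  # was client_type=ClientType.OPENAI
    client_config=OpenAIClientConfig(api_key="sk-...SET-THIS"),
)

model = ModelConfig(
    id="gpt-3.5-turbo",
    max_input_tokens=3000,
    # Chat parameters are now set only per model; the client-level
    # `parameters` field was removed by this commit.
    parameters=OpenAIChatParameters(temperature=0, stream=True),
    provider=provider.id,  # was provider_id
)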
4 changes: 2 additions & 2 deletions tests/test_cli_prompt.py
@@ -28,15 +28,15 @@ def test_prompt_with_temp_config_file(mock_home_dir):
     config_data = f"""
     providers:
       - id: openai
-        client_type: openai
+        client: openai
         client_config:
           api_key: {os.environ['OPENAI_API_KEY']}
     models:
       - id: gpt-3.5-turbo
         max_input_tokens: 3000
         parameters:
           temperature: 0
-        provider_id: openai
+        provider: openai
     """
 
     chat_dir = os.path.join(mock_home_dir, ".chat")
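Behavioral note: before this commit, a model's parameters block was merged over the provider's client-level defaults (client.parameters.dict(exclude_unset=True) updated with the model's values); with the client-level parameters field removed, the loader validates the model's block directly against the parameter class matching the provider's client. The rough sketch below mirrors the branch added to _load_and_validate_config; the helper name and sample values are made up, and the devchat.openai import path is assumed.

# Hypothetical helper mirroring the new per-model parameter handling.
from devchat.config import Client
from devchat.anthropic import AnthropicChatParameters
from devchat.openai import OpenAIChatParameters  # import path assumed


def build_parameters(client: Client, raw: dict):
    # Validate the raw parameter dict against the class matching the client.
    if client == Client.OPENAI:
        return OpenAIChatParameters(**raw)
    if client == Client.ANTHROPIC:
        return AnthropicChatParameters(**raw)
    raise ValueError(f"Invalid client: {client}")


params = build_parameters(Client.OPENAI, {"temperature": 0, "stream": True})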
