Skip to content

Commit

Permalink
Use directory of models in config
Browse files Browse the repository at this point in the history
- Update `get_model_config` to return a tuple of model and config
- Update `ModelConfig` and `ChatConfig` to handle models as a dictionary
- Update `update_model_config` in `ConfigManager` to accept model id
- Update tests in `tests/test_cli_prompt.py` and `tests/test_config.py`
  • Loading branch information
basicthinker committed Sep 11, 2023
1 parent dba1e33 commit 72ce1cc
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 89 deletions.
7 changes: 4 additions & 3 deletions devchat/_cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig
from devchat.store import Store
from devchat.utils import get_logger
from devchat._cli.utils import handle_errors, init_dir, model_config
from devchat._cli.utils import handle_errors, init_dir, get_model_config

logger = get_logger(__name__)

Expand All @@ -26,9 +26,10 @@ def log(skip, max_count, topic_root, delete):
repo_chat_dir, user_chat_dir = init_dir()

with handle_errors():
config = model_config(repo_chat_dir, user_chat_dir)
model, config = get_model_config(repo_chat_dir, user_chat_dir)
parameters_data = config.parameters.dict(exclude_unset=True) if config.parameters else {}
openai_config = OpenAIChatConfig(model=config.id, **parameters_data)
openai_config = OpenAIChatConfig(model=model, **parameters_data)

chat = OpenAIChat(openai_config)
store = Store(repo_chat_dir, chat)

Expand Down
10 changes: 3 additions & 7 deletions devchat/_cli/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig
from devchat.store import Store
from devchat.utils import parse_files
from devchat._cli.utils import handle_errors, init_dir, model_config
from devchat._cli.utils import handle_errors, init_dir, get_model_config


@click.command()
Expand Down Expand Up @@ -79,17 +79,13 @@ def prompt(content: Optional[str], parent: Optional[str], reference: Optional[Li
instruct_contents = parse_files(instruct)
context_contents = parse_files(context)

config = model_config(repo_chat_dir, user_chat_dir, model)
if not model:
model = config.id
model, config = get_model_config(repo_chat_dir, user_chat_dir, model)

parameters_data = config.parameters.dict(exclude_unset=True) if config.parameters else {}
if config_str:
config_data = json.loads(config_str)
parameters_data.update(config_data)

openai_config = OpenAIChatConfig(model=model,
**parameters_data)
openai_config = OpenAIChatConfig(model=model, **parameters_data)

chat = OpenAIChat(openai_config)
store = Store(repo_chat_dir, chat)
Expand Down
7 changes: 4 additions & 3 deletions devchat/_cli/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from devchat.store import Store
from devchat.openai import OpenAIChatConfig, OpenAIChat
from devchat.utils import get_logger
from devchat._cli.utils import init_dir, handle_errors, model_config
from devchat._cli.utils import init_dir, handle_errors, get_model_config

logger = get_logger(__name__)

Expand All @@ -20,9 +20,10 @@ def topic(list_topics: bool, skip: int, max_count: int):
repo_chat_dir, user_chat_dir = init_dir()

with handle_errors():
config = model_config(repo_chat_dir, user_chat_dir)
model, config = get_model_config(repo_chat_dir, user_chat_dir)
parameters_data = config.parameters.dict(exclude_unset=True) if config.parameters else {}
openai_config = OpenAIChatConfig(model=config.id, **parameters_data)
openai_config = OpenAIChatConfig(model=model, **parameters_data)

chat = OpenAIChat(openai_config)
store = Store(repo_chat_dir, chat)

Expand Down
8 changes: 4 additions & 4 deletions devchat/_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ def clone_git_repo(target_dir: str, repo_urls: List[str]):
raise GitCommandError(f"Failed to clone repository to {target_dir}")


def get_model_config(repo_chat_dir: str, user_chat_dir: str,
                     model: Optional[str] = None) -> Tuple[str, ModelConfig]:
    """Resolve the active model id and its configuration.

    :param repo_chat_dir: Per-repository chat directory (may hold a legacy config.json).
    :param user_chat_dir: User-level chat directory holding the YAML config.
    :param model: Optional model id; when omitted, ConfigManager picks the default.
    :return: Tuple of (model id, ModelConfig).
    """
    # Retire the legacy JSON config so ConfigManager starts fresh from YAML.
    # NOTE(review): os.rename raises on Windows if the '.old' file already
    # exists — os.replace would overwrite; confirm intended behavior.
    legacy_path = os.path.join(repo_chat_dir, 'config.json')
    if os.path.exists(legacy_path):
        os.rename(legacy_path, legacy_path + '.old')

    config_manager = ConfigManager(user_chat_dir)
    return config_manager.model_config(model)
98 changes: 42 additions & 56 deletions devchat/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
import os
import sys
from typing import Dict, List, Union, Optional
from typing import Dict, Tuple, Union, Optional
from pydantic import BaseModel
import yaml
from devchat.openai import OpenAIChatParameters
Expand Down Expand Up @@ -32,15 +32,14 @@ class AnthropicProviderConfig(ProviderConfig, extra='forbid'):


class ModelConfig(BaseModel, extra='forbid'):
    """Per-model settings; the model id is the key in ChatConfig.models."""
    # Prompt-token budget for the model; effectively unlimited by default.
    max_input_tokens: Optional[int] = sys.maxsize
    # Chat parameters; concrete type depends on the provider's client.
    parameters: Optional[Union[OpenAIChatParameters, AnthropicChatParameters]]
    # Key into ChatConfig.providers identifying which provider serves this model.
    provider: Optional[str]


class ChatConfig(BaseModel, extra='forbid'):
    """Top-level schema of the user config: providers and models keyed by id."""
    providers: Dict[str, ProviderConfig]
    models: Dict[str, ModelConfig]
    # Id of the model used when the caller does not specify one.
    default_model: Optional[str]


Expand All @@ -53,58 +52,48 @@ def __init__(self, dir_path: str):

def _load_and_validate_config(self) -> ChatConfig:
    """Load the YAML config and coerce provider/model entries to typed configs.

    :return: Validated ChatConfig built from the file contents.
    :raises ValueError: When a provider has an unknown client, or a model is
        missing its provider or references a provider with an unknown client.
    """
    with open(self.config_path, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)
    for provider, config in data['providers'].items():
        if config['client'] == Client.OPENAI:
            data['providers'][provider] = OpenAIProviderConfig(**config)
        elif config['client'] == Client.ANTHROPIC:
            data['providers'][provider] = AnthropicProviderConfig(**config)
        else:
            # Fix: config is still a plain dict here, so index it —
            # attribute access (config.client) would raise AttributeError.
            raise ValueError(f"Provider '{provider}' in {self.config_path} has invalid client: "
                             f"{config['client']}")
    for model, config in data['models'].items():
        if 'provider' not in config:
            raise ValueError(f"Model '{model}' in {self.config_path} is missing provider")
        if 'parameters' in config:
            # Provider entries were replaced by typed configs above.
            provider = data['providers'][config['provider']]
            if provider.client == Client.OPENAI:
                config['parameters'] = OpenAIChatParameters(**config['parameters'])
            elif provider.client == Client.ANTHROPIC:
                config['parameters'] = AnthropicChatParameters(**config['parameters'])
            else:
                raise ValueError(f"Model '{model}' in {self.config_path} has invalid provider")
    return ChatConfig(**data)

def model_config(self, model_id: Optional[str] = None) -> Tuple[str, ModelConfig]:
    """Return (model id, config) for *model_id*, or for the default model.

    :param model_id: Model key in the config; when falsy, fall back to the
        configured default model, then to the first model listed.
    :return: Tuple of (model id, ModelConfig).
    :raises ValueError: When no models exist or *model_id* is unknown.
    """
    if not model_id:
        if self.config.default_model:
            # Resolve the default through the same lookup (validates it exists).
            return self.model_config(self.config.default_model)
        if self.config.models:
            # Dict preserves insertion order: first configured model wins.
            return next(iter(self.config.models.items()))
        raise ValueError(f"No models found in {self.config_path}")
    if model_id not in self.config.models:
        raise ValueError(f"Model '{model_id}' not found in {self.config_path}")
    return model_id, self.config.models[model_id]

def update_model_config(self, model_id: str, new_config: ModelConfig) -> ModelConfig:
    """Merge *new_config* into the stored config for *model_id*.

    Only fields explicitly set on *new_config* are applied; parameters are
    merged field-by-field rather than replaced wholesale.

    :param model_id: Key of the model to update.
    :param new_config: Partial config whose set fields override the stored ones.
    :return: The updated (mutated in place) stored ModelConfig.
    :raises ValueError: Propagated from model_config when *model_id* is unknown.
    """
    _, old_config = self.model_config(model_id)
    if new_config.max_input_tokens is not None:
        old_config.max_input_tokens = new_config.max_input_tokens
    if new_config.parameters is not None:
        # Fix: the stored config may have no parameters yet; start from {}
        # instead of crashing on None.dict().
        updated_parameters = (old_config.parameters.dict(exclude_unset=True)
                              if old_config.parameters is not None else {})
        updated_parameters.update(new_config.parameters.dict(exclude_unset=True))
        old_config.parameters = type(new_config.parameters)(**updated_parameters)
    return old_config

def sync(self):
with open(self.config_path, 'w', encoding='utf-8') as file:
Expand Down Expand Up @@ -134,31 +123,28 @@ def _create_sample_config(self):
timeout=30
)
},
models=[
ModelConfig(
id="gpt-4",
models={
"gpt-4": ModelConfig(
max_input_tokens=6000,
parameters=OpenAIChatParameters(temperature=0, stream=True),
provider='devchat.ai'
),
ModelConfig(
id="gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k": ModelConfig(
max_input_tokens=12000,
parameters=OpenAIChatParameters(temperature=0, stream=True),
provider='devchat.ai'
),
ModelConfig(
id="gpt-3.5-turbo",
"gpt-3.5-turbo": ModelConfig(
max_input_tokens=3000,
parameters=OpenAIChatParameters(temperature=0, stream=True),
provider='devchat.ai'
),
ModelConfig(
id="claude-2",
"claude-2": ModelConfig(
parameters=AnthropicChatParameters(max_tokens_to_sample=20000),
provider='anthropic.com'
)
]
},
default_model="gpt-3.5-turbo"
)
with open(self.config_path, 'w', encoding='utf-8') as file:
yaml.dump(sample_config.dict(exclude_unset=True), file)
12 changes: 5 additions & 7 deletions tests/test_cli_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_prompt_with_temp_config_file(mock_home_dir):
client: openai
api_key: {os.environ['OPENAI_API_KEY']}
models:
- id: gpt-3.5-turbo
gpt-3.5-turbo:
max_input_tokens: 3000
parameters:
temperature: 0
Expand Down Expand Up @@ -198,16 +198,15 @@ def test_prompt_without_repo(mock_home_dir): # pylint: disable=W0613
def test_prompt_tokens_exceed_config(mock_home_dir): # pylint: disable=W0613
model = "gpt-3.5-turbo"
max_input_tokens = 2000
model_config = ModelConfig(
id=model,
config = ModelConfig(
max_input_tokens=max_input_tokens,
parameters=OpenAIChatParameters(temperature=0)
)

chat_dir = os.path.join(mock_home_dir, ".chat")
os.makedirs(chat_dir)
config_manager = ConfigManager(chat_dir)
config_manager.update_model_config(model_config)
config_manager.update_model_config(model, config)
config_manager.config.default_model = model
config_manager.sync()

Expand All @@ -223,16 +222,15 @@ def test_prompt_tokens_exceed_config(mock_home_dir): # pylint: disable=W0613
def test_file_tokens_exceed_config(mock_home_dir, tmpdir): # pylint: disable=W0613
model = "gpt-3.5-turbo"
max_input_tokens = 2000
model_config = ModelConfig(
id=model,
config = ModelConfig(
max_input_tokens=max_input_tokens,
parameters=OpenAIChatParameters(temperature=0)
)

chat_dir = os.path.join(mock_home_dir, ".chat")
os.makedirs(chat_dir)
config_manager = ConfigManager(chat_dir)
config_manager.update_model_config(model_config)
config_manager.update_model_config(model, config)
config_manager.config.default_model = model
config_manager.sync()

Expand Down
17 changes: 8 additions & 9 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,21 @@ def test_load_and_validate_config(tmp_path):

def test_get_model_config(tmp_path):
    """Looking up 'gpt-4' returns the sample entry written by ConfigManager."""
    config_manager = ConfigManager(tmp_path)
    # model_config now returns a (model id, ModelConfig) tuple.
    _, config = config_manager.model_config('gpt-4')
    assert config.max_input_tokens == 6000
    assert config.parameters.temperature == 0
    assert config.parameters.stream is True


def test_update_model_config(tmp_path):
    """Updating a model merges set fields and leaves unset ones intact."""
    model = 'gpt-4'
    config_manager = ConfigManager(tmp_path)
    config = ModelConfig(
        max_input_tokens=7000,
        parameters=OpenAIChatParameters(temperature=0.5)
    )
    updated_model_config = config_manager.update_model_config(model, config)
    # The stored config is mutated in place and returned.
    assert updated_model_config == config_manager.model_config(model)[1]
    assert updated_model_config.max_input_tokens == 7000
    assert updated_model_config.parameters.temperature == 0.5
    # stream was not set on the patch, so the sample default survives.
    assert updated_model_config.parameters.stream is True
Expand Down

0 comments on commit 72ce1cc

Please sign in to comment.