Skip to content

Commit

Permalink
Refactor response token counting in Prompt classes
Browse files Browse the repository at this point in the history
- Simplify '_count_response_tokens' in 'OpenAIPrompt'.
- Move '_response_tokens' check to 'response_tokens' in 'Prompt'.
- Update '_hash' in 'Prompt' to use refactored property.
  • Loading branch information
basicthinker committed Aug 1, 2023
1 parent 9410725 commit 2a9f255
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
7 changes: 1 addition & 6 deletions devchat/openai/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,7 @@ def append_response(self, delta_str: str) -> str:
return delta_content

def _count_response_tokens(self) -> int:
    """Sum the token counts of all responses of this prompt.

    Returns:
        int: total number of tokens across ``self.responses``.

    Note:
        No caching happens here — the ``Prompt.response_tokens`` property
        memoizes the result in ``self._response_tokens``, so this helper
        always recomputes from scratch.  (The previous cached version was
        dead code once the caching moved to the property.)
    """
    # `openai_response_tokens` is a project helper that counts the tokens
    # of a single response message for the given model — TODO confirm its
    # exact contract against devchat/utils.py.
    return sum(openai_response_tokens(resp.to_dict(), self.model)
               for resp in self.responses)

def _validate_model(self, response_data: dict):
if not response_data['model'].startswith(self.model):
Expand Down
4 changes: 3 additions & 1 deletion devchat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def request_tokens(self) -> int:

@property
def response_tokens(self) -> int:
    """Number of tokens in the responses, computed lazily.

    The first access (or any access while the stored value is falsy)
    delegates to ``self._count_response_tokens()`` and caches the result
    in ``self._response_tokens``.
    """
    cached = self._response_tokens
    if not cached:
        cached = self._count_response_tokens()
        self._response_tokens = cached
    return cached

@abstractmethod
Expand Down Expand Up @@ -189,7 +191,7 @@ def finalize_hash(self) -> str:
if self._hash:
return self._hash

self._count_response_tokens()
self._response_tokens = self._count_response_tokens()

data = asdict(self)
data.pop('_hash')
Expand Down
3 changes: 2 additions & 1 deletion devchat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def openai_message_tokens(message: dict, model: str) -> int:
for key, value in message.items():
if key == 'function_call':
value = json.dumps(value)
num_tokens += len(encoding.encode(value))
if value:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
return num_tokens
Expand Down
7 changes: 5 additions & 2 deletions tests/test_cli_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_prompt_with_content(git_repo): # pylint: disable=W0613

def test_prompt_with_temp_config_file(git_repo):
config_data = {
'model': 'gpt-3.5-turbo-0301',
'model': 'gpt-3.5-turbo',
'provider': 'OpenAI',
'tokens-per-prompt': 3000,
'OpenAI': {
Expand Down Expand Up @@ -110,7 +110,8 @@ def test_prompt_with_functions(git_repo, functions_file): # pylint: disable=W06
# call with -f option
result = runner.invoke(main, ['prompt', '-m', 'gpt-3.5-turbo', '-f', functions_file,
"What is the weather like in Boston?"])
print(result.output)
if result.exit_code:
print(result.output)
assert result.exit_code == 0
content = get_content(result.output)
assert 'finish_reason: function_call' in content
Expand All @@ -131,6 +132,8 @@ def test_prompt_log_with_functions(git_repo, functions_file): # pylint: disable
# call with -f option
result = runner.invoke(main, ['prompt', '-m', 'gpt-3.5-turbo', '-f', functions_file,
'What is the weather like in Boston?'])
if result.exit_code:
print(result.output)
assert result.exit_code == 0
prompt_hash = get_prompt_hash(result.output)
result = runner.invoke(main, ['log', '-t', prompt_hash])
Expand Down

0 comments on commit 2a9f255

Please sign in to comment.