Skip to content

Commit

Permalink
fix: Implement self-correction for invalid LLM responses
Browse files Browse the repository at this point in the history
- Fix the parsing of invalid LLM responses by appending an error message to the prompt and allowing the LLM to fix its mistakes.
- Update the `OpenAIProvider` to handle the self-correction process and limit the number of attempts to fix parsing errors.
- Update the `BaseAgent` to benefit from the new parsing and parse-fixing mechanism.

This change ensures that the system can handle and recover from errors in parsing LLM responses.

Hopefully this fixes Significant-Gravitas#1407 once and for all.
  • Loading branch information
Pwuts committed Dec 13, 2023
1 parent 6b0d0d4 commit acf4df9
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 46 deletions.
19 changes: 13 additions & 6 deletions autogpts/autogpt/autogpt/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt
from autogpt.core.resource.model_providers import (
AssistantChatMessageDict,
ChatMessage,
ChatModelProvider,
ChatModelResponse,
)
from autogpt.llm.api_manager import ApiManager
from autogpt.logs.log_cycle import (
Expand All @@ -44,7 +44,12 @@
OneShotAgentPromptConfiguration,
OneShotAgentPromptStrategy,
)
from .utils.exceptions import AgentException, CommandExecutionError, UnknownCommandError
from .utils.exceptions import (
AgentException,
AgentTerminated,
CommandExecutionError,
UnknownCommandError,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,6 +81,8 @@ class Agent(
description=__doc__,
)

prompt_strategy: OneShotAgentPromptStrategy

def __init__(
self,
settings: AgentSettings,
Expand Down Expand Up @@ -164,20 +171,20 @@ def on_before_think(self, *args, **kwargs) -> ChatPrompt:
return prompt

def parse_and_process_response(
self, llm_response: ChatModelResponse, *args, **kwargs
self, llm_response: AssistantChatMessageDict, *args, **kwargs
) -> Agent.ThoughtProcessOutput:
for plugin in self.config.plugins:
if not plugin.can_handle_post_planning():
continue
llm_response.response["content"] = plugin.post_planning(
llm_response.response.get("content", "")
llm_response["content"] = plugin.post_planning(
llm_response.get("content", "")
)

(
command_name,
arguments,
assistant_reply_dict,
) = self.prompt_strategy.parse_response_content(llm_response.response)
) = self.prompt_strategy.parse_response_content(llm_response)

self.log_cycle_handler.log_cycle(
self.ai_profile.ai_name,
Expand Down
18 changes: 10 additions & 8 deletions autogpts/autogpt/autogpt/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from autogpt.config import Config
from autogpt.core.prompting.base import PromptStrategy
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessageDict,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
Expand Down Expand Up @@ -247,7 +248,7 @@ async def propose_action(self) -> ThoughtProcessOutput:
prompt = self.on_before_think(prompt, scratchpad=self._prompt_scratchpad)

logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
raw_response = await self.llm_provider.create_chat_completion(
response = await self.llm_provider.create_chat_completion(
prompt.messages,
functions=get_openai_command_specs(
self.command_registry.list_available_commands(self)
Expand All @@ -256,11 +257,16 @@ async def propose_action(self) -> ThoughtProcessOutput:
if self.config.use_functions_api
else [],
model_name=self.llm.name,
completion_parser=lambda r: self.parse_and_process_response(
r,
prompt,
scratchpad=self._prompt_scratchpad,
),
)
self.config.cycle_count += 1

return self.on_response(
llm_response=raw_response,
llm_response=response,
prompt=prompt,
scratchpad=self._prompt_scratchpad,
)
Expand Down Expand Up @@ -397,18 +403,14 @@ def on_response(
The parsed command name and command args, if any, and the agent thoughts.
"""

return self.parse_and_process_response(
llm_response,
prompt,
scratchpad=scratchpad,
)
return llm_response.parsed_result

# TODO: update memory/context

@abstractmethod
def parse_and_process_response(
self,
llm_response: ChatModelResponse,
llm_response: AssistantChatMessageDict,
prompt: ChatPrompt,
scratchpad: PromptScratchpad,
) -> ThoughtProcessOutput:
Expand Down
52 changes: 37 additions & 15 deletions autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class OpenAIModelName(str, enum.Enum):


class OpenAIConfiguration(ModelProviderConfiguration):
fix_failed_parse_tries: int = UserConfigurable(3)
pass


Expand Down Expand Up @@ -363,24 +364,45 @@ async def create_chat_completion(
model_prompt += completion_kwargs["messages"]
del completion_kwargs["messages"]

response = await self._create_chat_completion(
messages=model_prompt,
**completion_kwargs,
)
response_args = {
"model_info": OPEN_AI_CHAT_MODELS[model_name],
"prompt_tokens_used": response.usage.prompt_tokens,
"completion_tokens_used": response.usage.completion_tokens,
}

response_message = response.choices[0].message.to_dict_recursive()
if tool_calls_compat_mode:
response_message["tool_calls"] = _tool_calls_compat_extract_calls(
response_message["content"]
attempts = 0
while True:
response = await self._create_chat_completion(
messages=model_prompt,
**completion_kwargs,
)
response_args = {
"model_info": OPEN_AI_CHAT_MODELS[model_name],
"prompt_tokens_used": response.usage.prompt_tokens,
"completion_tokens_used": response.usage.completion_tokens,
}

response_message = response.choices[0].message.to_dict_recursive()
if tool_calls_compat_mode:
response_message["tool_calls"] = _tool_calls_compat_extract_calls(
response_message["content"]
)

# If parsing the response fails, append the error to the prompt, and let the
# LLM fix its mistake(s).
try:
attempts += 1
parsed_response = completion_parser(response_message)
break
except Exception as e:
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
self._logger.debug(
f"Parsing failed on response: '''{response_message}'''"
)
if attempts < self._configuration.fix_failed_parse_tries:
model_prompt.append(
ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
)
else:
raise

response = ChatModelResponse(
response=response_message,
parsed_result=completion_parser(response_message),
parsed_result=parsed_response,
**response_args,
)
self._budget.update_usage_and_cost(response)
Expand Down
10 changes: 0 additions & 10 deletions autogpts/autogpt/autogpt/core/utils/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,8 @@ def validate_object(
validator = Draft7Validator(self.to_dict())

if errors := sorted(validator.iter_errors(object), key=lambda e: e.path):
for error in errors:
logger.debug(f"JSON Validation Error: {error}")

logger.error(json.dumps(object, indent=4))
logger.error("The following issues were found:")

for error in errors:
logger.error(f"Error: {error.message}")
return False, errors

logger.debug("The JSON object is valid.")

return True, None

def to_typescript_object_interface(self, interface_name: str = "") -> str:
Expand Down
14 changes: 7 additions & 7 deletions autogpts/autogpt/autogpt/json_utils/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def extract_dict_from_response(response_content: str) -> dict[str, Any]:

# Response content comes from OpenAI as a Python `str(content_dict)`.
# `literal_eval` does the reverse of `str(dict)`.
try:
return ast.literal_eval(response_content)
except BaseException as e:
logger.info(f"Error parsing JSON response with literal_eval {e}")
logger.debug(f"Invalid JSON received in response:\n{response_content}")
# TODO: How to raise an error here without causing the program to exit?
return {}
result = ast.literal_eval(response_content)
if not isinstance(result, dict):
raise ValueError(
f"Response '''{response_content}''' evaluated to "
f"non-dict value {repr(result)}"
)
return result

0 comments on commit acf4df9

Please sign in to comment.