forked from balrog-ai/BALROG
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat: more robust naive agent (balrog-ai#13)
* feat: robust naive agent * docs
- Loading branch information
1 parent
df38b8a
commit 9c65dad
Showing
3 changed files
with
74 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import copy | ||
import re | ||
|
||
from balrog.agents.base import BaseAgent | ||
|
||
|
||
class RobustNaiveAgent(BaseAgent): | ||
"""An agent that generates actions based on observations without complex reasoning.""" | ||
|
||
def __init__(self, client_factory, prompt_builder): | ||
"""Initialize the NaiveAgent with a client and prompt builder.""" | ||
super().__init__(client_factory, prompt_builder) | ||
self.client = client_factory() | ||
|
||
def act(self, obs, prev_action=None): | ||
"""Generate the next action based on the observation and previous action. | ||
Args: | ||
obs (dict): The current observation in the environment. | ||
prev_action (str, optional): The previous action taken. | ||
Returns: | ||
str: The selected action from the LLM response. | ||
""" | ||
if prev_action: | ||
self.prompt_builder.update_action(prev_action) | ||
|
||
self.prompt_builder.update_observation(obs) | ||
|
||
messages = self.prompt_builder.get_prompt() | ||
|
||
# Updated instructions to require a very strict output format | ||
naive_instruction = """ | ||
You must choose exactly one of the listed actions and output it strictly in the following format: | ||
<|ACTION|>YOUR_CHOSEN_ACTION</|ACTION|> | ||
You must not output any other text before or after these tags. No explanation, no reasoning, just the action within these tags. | ||
""".strip() | ||
|
||
if messages and messages[-1].role == "user": | ||
messages[-1].content += "\n\n" + naive_instruction | ||
|
||
response = self.client.generate(messages) | ||
final_answer = self._extract_final_answer(response) | ||
return final_answer | ||
|
||
def _extract_final_answer(self, answer): | ||
"""Extract the action from the completion by looking for <|ACTION|> ... </|ACTION|> tags. | ||
Args: | ||
answer (LLMResponse): The response from the LLM. | ||
Returns: | ||
LLMResponse: The sanitized response containing just the extracted action. | ||
""" | ||
completion_text = answer.completion | ||
# Use a regex to find the text inside <|ACTION|> and </|ACTION|> | ||
match = re.search(r"<\|ACTION\|>(.*?)</\|ACTION\|>", completion_text, re.DOTALL) | ||
if match: | ||
extracted_action = match.group(1).strip() | ||
else: | ||
# If no match is found, fallback to the original completion (or handle error) | ||
extracted_action = completion_text.strip() | ||
|
||
final_answer = copy.deepcopy(answer) | ||
final_answer = final_answer._replace(completion=extracted_action) | ||
|
||
return final_answer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters