forked from balrog-ai/BALROG
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Few-Shot Learning Support to Balrog (balrog-ai#4)
* Few Shot Learning * turn off dummy actions * simplify dataset * add docs for few shot learning * update docs * fix download link * fix loading dataset * add parameter to limit the size of icl context * sample demonstrations randomly * quick fix * set default max_icl_history to 1000
- Loading branch information
1 parent
395ca9b
commit 67a8d26
Showing
9 changed files
with
326 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -165,5 +165,5 @@ cython_debug/ | |
/outputs | ||
/tw_games | ||
tw-games.zip | ||
/demos | ||
/demos.zip | ||
/records | ||
/records.zip |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import copy | ||
import re | ||
from typing import List, Optional | ||
|
||
from balrog.agents.base import BaseAgent | ||
|
||
|
||
class Message: | ||
def __init__(self, role: str, content: str, attachment: Optional[object] = None): | ||
self.role = role # 'system', 'user', 'assistant' | ||
self.content = content # String content of the message | ||
self.attachment = attachment | ||
|
||
def __repr__(self): | ||
return f"Message(role={self.role}, content={self.content}, attachment={self.attachment})" | ||
|
||
|
||
class FewShotAgent(BaseAgent): | ||
def __init__(self, client_factory, prompt_builder, max_icl_history): | ||
"""Initialize the FewShotAgent with a client and prompt builder.""" | ||
super().__init__(client_factory, prompt_builder) | ||
self.client = client_factory() | ||
self.icl_episodes = [] | ||
self.icl_events = [] | ||
self.max_icl_history = max_icl_history | ||
self.cached_icl = False | ||
|
||
def update_icl_observation(self, obs: dict): | ||
long_term_context = obs["text"].get("long_term_context", "") | ||
self.icl_events.append( | ||
{ | ||
"type": "icl_observation", | ||
"text": long_term_context, | ||
} | ||
) | ||
|
||
def update_icl_action(self, action: str): | ||
self.icl_events.append( | ||
{ | ||
"type": "icl_action", | ||
"action": action, | ||
} | ||
) | ||
|
||
def cache_icl(self): | ||
self.client.cache_icl_demo(self.get_icl_prompt()) | ||
self.cached_icl = True | ||
|
||
def wrap_episode(self): | ||
icl_episode = [] | ||
icl_episode.append( | ||
Message(role="user", content=f"****** START OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******") | ||
) | ||
for event in self.icl_events: | ||
if event["type"] == "icl_observation": | ||
content = "Obesrvation:\n" + event["text"] | ||
message = Message(role="user", content=content) | ||
elif event["type"] == "icl_action": | ||
content = event["action"] | ||
message = Message(role="assistant", content=content) | ||
icl_episode.append(message) | ||
icl_episode.append( | ||
Message(role="user", content=f"****** END OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******") | ||
) | ||
|
||
self.icl_episodes.append(icl_episode) | ||
self.icl_events = [] | ||
|
||
def get_icl_prompt(self) -> List[Message]: | ||
icl_instruction = Message( | ||
role="user", | ||
content=self.prompt_builder.system_prompt.replace( | ||
"PLAY", | ||
"First, observe the demonstrations provided and learn from them!", | ||
), | ||
) | ||
|
||
# unroll the wrapped icl episodes messages | ||
icl_messages = [icl_instruction] | ||
i = 0 | ||
for icl_episode in self.icl_episodes: | ||
episode_steps = len(icl_episode) - 2 # not count start and end messages | ||
if i + episode_steps <= self.max_icl_history: | ||
icl_messages.extend(icl_episode) | ||
i += episode_steps | ||
else: | ||
icl_episode = icl_episode[: self.max_icl_history - i + 1] + [ | ||
icl_episode[-1] | ||
] # +1 for start message -1 for end message | ||
icl_messages.extend(icl_episode) | ||
i += len(icl_episode) - 2 # not count start and end messages | ||
break | ||
|
||
end_demo_message = Message( | ||
role="user", | ||
content="****** Now it's your turn to play the game! ******", | ||
) | ||
icl_messages.append(end_demo_message) | ||
|
||
return icl_messages | ||
|
||
def act(self, obs, prev_action=None): | ||
"""Generate the next action based on the observation and previous action. | ||
Args: | ||
obs (dict): The current observation in the environment. | ||
prev_action (str, optional): The previous action taken. | ||
Returns: | ||
str: The selected action from the LLM response. | ||
""" | ||
if prev_action: | ||
self.prompt_builder.update_action(prev_action) | ||
|
||
self.prompt_builder.update_observation(obs) | ||
|
||
if not self.cached_icl: | ||
messages = self.get_icl_prompt() | ||
else: | ||
messages = [] | ||
|
||
messages.extend(self.prompt_builder.get_prompt(icl_episodes=True)) | ||
|
||
naive_instruction = """ | ||
You always have to output one of the above actions at a time and no other text. You always have to output an action until the episode terminates. | ||
""".strip() | ||
|
||
if messages and messages[-1].role == "user": | ||
messages[-1].content += "\n\n" + naive_instruction | ||
|
||
response = self.client.generate(messages) | ||
|
||
final_answer = self._extract_final_answer(response) | ||
|
||
return final_answer | ||
|
||
def _extract_final_answer(self, answer): | ||
"""Sanitize the final answer, keeping only alphabetic characters. | ||
Args: | ||
answer (LLMResponse): The response from the LLM. | ||
Returns: | ||
LLMResponse: The sanitized response. | ||
""" | ||
|
||
def filter_letters(input_string): | ||
return re.sub(r"[^a-zA-Z\s:]", "", input_string) | ||
|
||
final_answer = copy.deepcopy(answer) | ||
final_answer = final_answer._replace(completion=filter_letters(final_answer.completion)) | ||
|
||
return final_answer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import glob | ||
import logging | ||
import os | ||
import random | ||
import re | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
|
||
|
||
def natural_sort_key(s): | ||
return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(s))] | ||
|
||
|
||
def choice_excluding(lst, excluded_element): | ||
possible_choices = [item for item in lst if item != excluded_element] | ||
return random.choice(possible_choices) | ||
|
||
|
||
class InContextDataset: | ||
def __init__(self, config, env_name, original_cwd) -> None: | ||
self.config = config | ||
self.env_name = env_name | ||
self.original_cwd = original_cwd | ||
|
||
def icl_episodes(self, task): | ||
demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task | ||
return list(sorted(glob.glob(os.path.join(demos_dir, "**/*.npz"), recursive=True), key=natural_sort_key)) | ||
|
||
def extract_seed(self, demo_path): | ||
# extract seed from record, example format: `20241201T225823-seed13-rew1.00-len47.npz` | ||
seed = [part.removeprefix("seed") for part in Path(demo_path).stem.split("-") if "seed" in part] | ||
return int(seed[0]) | ||
|
||
def demo_task(self, task): | ||
# use different task - avoid the case where we put the solution into the context | ||
if self.env_name == "babaisai": | ||
task = choice_excluding(self.config.tasks[f"{self.env_name}_tasks"], task) | ||
|
||
return task | ||
|
||
def demo_path(self, i, task): | ||
icl_episodes = self.icl_episodes(task) | ||
demo_path = icl_episodes[i % len(icl_episodes)] | ||
|
||
# use different seed - avoid the case where we put the solution into the context | ||
if self.env_name == "textworld": | ||
from balrog.environments.textworld import global_textworld_context | ||
|
||
textworld_context = global_textworld_context( | ||
tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs | ||
) | ||
next_seed = textworld_context.count[task] | ||
demo_seed = self.extract_seed(demo_path) | ||
if next_seed == demo_seed: | ||
demo_path = self.icl_episodes(task)[i + 1] | ||
|
||
return demo_path | ||
|
||
def load_episode(self, filename): | ||
# Load the compressed NPZ file | ||
with np.load(filename, allow_pickle=True) as data: | ||
# Convert to dictionary if you want | ||
episode = {k: data[k] for k in data.files} | ||
return episode | ||
|
||
def load_in_context_learning_episodes(self, num_episodes, task, agent): | ||
demo_task = self.demo_task(task) | ||
demo_paths = [self.demo_path(i, demo_task) for i in range(len(self.icl_episodes(task)))] | ||
random.shuffle(demo_paths) | ||
demo_paths = demo_paths[:num_episodes] | ||
|
||
for demo_path in demo_paths: | ||
self.load_in_context_learning_episode(demo_path, agent) | ||
|
||
def load_in_context_learning_episode(self, demo_path, agent): | ||
episode = self.load_episode(demo_path) | ||
|
||
actions = episode.pop("action").tolist() | ||
rewards = episode.pop("reward").tolist() | ||
terminated = episode.pop("terminated") | ||
truncated = episode.pop("truncated") | ||
dones = np.any([terminated, truncated], axis=0).tolist() | ||
observations = [dict(zip(episode.keys(), values)) for values in zip(*episode.values())] | ||
|
||
# first transition only contains observation (like env.reset()) | ||
observation, action, reward, done = observations.pop(0), actions.pop(0), rewards.pop(0), dones.pop(0) | ||
agent.update_icl_observation(observation) | ||
|
||
for observation, action, reward, done in zip(observations, actions, rewards, dones): | ||
action = str(action) | ||
|
||
agent.update_icl_action(action) | ||
agent.update_icl_observation(observation) | ||
|
||
if done: | ||
break | ||
|
||
if not done: | ||
logging.info("icl trajectory ended without done") | ||
|
||
agent.wrap_episode() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.