Add Few-Shot Learning Support to Balrog (balrog-ai#4)
* Few Shot Learning

* turn off dummy actions

* simplify dataset

* add docs for few shot learning

* update docs

* fix download link

* fix loading dataset

* add parameter to limit the size of icl context

* sample demonstrations randomly

* quick fix

* set default max_icl_history to 1000
BartekCupial authored Dec 5, 2024
1 parent 395ca9b commit 67a8d26
Showing 9 changed files with 326 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .gitignore
@@ -165,5 +165,5 @@ cython_debug/
/outputs
/tw_games
tw-games.zip
/demos
/demos.zip
/records
/records.zip
1 change: 1 addition & 0 deletions README.md
@@ -72,6 +72,7 @@ python eval.py \
## Documentation
- [Evaluation Guide](https://github.com/balrog-ai/BALROG/blob/main/docs/evaluation.md) - Detailed instructions for various evaluation scenarios
- [Agent Development](https://github.com/balrog-ai/BALROG/blob/main/docs/agents.md) - Tutorial on creating custom agents
- [Few Shot Learning](https://github.com/balrog-ai/BALROG/blob/main/docs/few_shot_learning.md) - Instructions for running few-shot learning evaluations

We welcome contributions! Please see our [Contributing Guidelines](https://github.com/balrog-ai/BALROG/blob/main/docs/contribution.md) for details.

3 changes: 3 additions & 0 deletions balrog/agents/__init__.py
@@ -4,6 +4,7 @@
from .chain_of_thought import ChainOfThoughtAgent
from .custom import CustomAgent
from .dummy import DummyAgent
from .few_shot import FewShotAgent
from .naive import NaiveAgent


@@ -47,6 +48,8 @@ def create_agent(self):
return DummyAgent(client_factory, prompt_builder)
elif self.config.agent.type == "custom":
return CustomAgent(client_factory, prompt_builder)
elif self.config.agent.type == "few_shot":
return FewShotAgent(client_factory, prompt_builder, self.config.agent.max_icl_history)

else:
raise ValueError(f"Unknown agent type: {self.config.agent}")
153 changes: 153 additions & 0 deletions balrog/agents/few_shot.py
@@ -0,0 +1,153 @@
import copy
import re
from typing import List, Optional

from balrog.agents.base import BaseAgent


class Message:
def __init__(self, role: str, content: str, attachment: Optional[object] = None):
self.role = role # 'system', 'user', 'assistant'
self.content = content # String content of the message
self.attachment = attachment

def __repr__(self):
return f"Message(role={self.role}, content={self.content}, attachment={self.attachment})"


class FewShotAgent(BaseAgent):
def __init__(self, client_factory, prompt_builder, max_icl_history):
"""Initialize the FewShotAgent with a client and prompt builder."""
super().__init__(client_factory, prompt_builder)
self.client = client_factory()
self.icl_episodes = []
self.icl_events = []
self.max_icl_history = max_icl_history
self.cached_icl = False

def update_icl_observation(self, obs: dict):
long_term_context = obs["text"].get("long_term_context", "")
self.icl_events.append(
{
"type": "icl_observation",
"text": long_term_context,
}
)

def update_icl_action(self, action: str):
self.icl_events.append(
{
"type": "icl_action",
"action": action,
}
)

def cache_icl(self):
self.client.cache_icl_demo(self.get_icl_prompt())
self.cached_icl = True

def wrap_episode(self):
icl_episode = []
icl_episode.append(
Message(role="user", content=f"****** START OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
)
for event in self.icl_events:
if event["type"] == "icl_observation":
content = "Obesrvation:\n" + event["text"]
message = Message(role="user", content=content)
elif event["type"] == "icl_action":
content = event["action"]
message = Message(role="assistant", content=content)
icl_episode.append(message)
icl_episode.append(
Message(role="user", content=f"****** END OF DEMONSTRATION EPISODE {len(self.icl_episodes) + 1} ******")
)

self.icl_episodes.append(icl_episode)
self.icl_events = []

def get_icl_prompt(self) -> List[Message]:
icl_instruction = Message(
role="user",
content=self.prompt_builder.system_prompt.replace(
"PLAY",
"First, observe the demonstrations provided and learn from them!",
),
)

# unroll the wrapped icl episodes messages
icl_messages = [icl_instruction]
i = 0
for icl_episode in self.icl_episodes:
episode_steps = len(icl_episode) - 2  # start/end banner messages don't count toward the budget
if i + episode_steps <= self.max_icl_history:
icl_messages.extend(icl_episode)
i += episode_steps
else:
# keep the start banner plus the steps that still fit, then re-append the end banner
icl_episode = icl_episode[: self.max_icl_history - i + 1] + [icl_episode[-1]]
icl_messages.extend(icl_episode)
i += len(icl_episode) - 2  # banners excluded from the step count
break

end_demo_message = Message(
role="user",
content="****** Now it's your turn to play the game! ******",
)
icl_messages.append(end_demo_message)

return icl_messages

def act(self, obs, prev_action=None):
"""Generate the next action based on the observation and previous action.
Args:
obs (dict): The current observation in the environment.
prev_action (str, optional): The previous action taken.
Returns:
str: The selected action from the LLM response.
"""
if prev_action:
self.prompt_builder.update_action(prev_action)

self.prompt_builder.update_observation(obs)

if not self.cached_icl:
messages = self.get_icl_prompt()
else:
messages = []

messages.extend(self.prompt_builder.get_prompt(icl_episodes=True))

naive_instruction = """
You always have to output one of the above actions at a time and no other text. You always have to output an action until the episode terminates.
""".strip()

if messages and messages[-1].role == "user":
messages[-1].content += "\n\n" + naive_instruction

response = self.client.generate(messages)

final_answer = self._extract_final_answer(response)

return final_answer

def _extract_final_answer(self, answer):
"""Sanitize the final answer, keeping only alphabetic characters.
Args:
answer (LLMResponse): The response from the LLM.
Returns:
LLMResponse: The sanitized response.
"""

def filter_letters(input_string):
return re.sub(r"[^a-zA-Z\s:]", "", input_string)

final_answer = copy.deepcopy(answer)
final_answer = final_answer._replace(completion=filter_letters(final_answer.completion))

return final_answer
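The budget arithmetic in get_icl_prompt is easy to misread, so here is a minimal standalone sketch of the same truncation policy with illustrative data (the names and values are made up for the example, not the repository's API):

# Two demonstration episodes of 4 steps each, wrapped in start/end banners.
episodes = [
    ["<start 1>", "obs", "act", "obs", "act", "<end 1>"],
    ["<start 2>", "obs", "act", "obs", "act", "<end 2>"],
]
max_icl_history = 6  # step budget; banners are excluded

messages, used = [], 0
for episode in episodes:
    steps = len(episode) - 2  # banners don't count toward the budget
    if used + steps <= max_icl_history:
        messages.extend(episode)
        used += steps
    else:
        remaining = max_icl_history - used
        # keep the start banner plus the steps that still fit, then the end banner
        messages.extend(episode[: remaining + 1] + [episode[-1]])
        break

print(messages)
# ['<start 1>', 'obs', 'act', 'obs', 'act', '<end 1>', '<start 2>', 'obs', 'act', '<end 2>']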
5 changes: 4 additions & 1 deletion balrog/config/config.yaml
@@ -4,6 +4,8 @@ agent:
max_history: 16 # Maximum number of previous turns to keep in the dialogue history
max_image_history: 0 # Maximum number of images to keep in the history
max_cot_history: 1 # Maximum number of chain-of-thought steps to keep in history (if using 'cot' type of agent)
max_icl_history: 1000 # Maximum number of ICL steps to keep in history (if using 'few_shot' type of agent)
cache_icl: False # Whether to cache the ICL demonstration prompt (currently only used with the Gemini client)

eval:
output_dir: "results" # Directory where evaluation results will be saved
@@ -19,7 +21,8 @@ eval:
max_steps_per_episode: null # Max steps per episode; null uses the environment default
save_trajectories: True # Whether to save agent trajectories (text only)
save_images: False # Whether to save images from the environment

icl_episodes: 1 # Number of demonstration episodes to load into the context (if using 'few_shot' type of agent)
icl_dataset: records # Directory containing the recorded demonstration episodes

client:
client_name: openai # LLM client to use (e.g., 'openai', 'gemini', 'claude')
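Since the evaluator already uses OmegaConf (see balrog/evaluator.py below), the new keys can also be overridden without editing the file. A minimal sketch, assuming dot-list overrides fit how you launch runs:

from omegaconf import OmegaConf

# Baseline mirroring the new few-shot keys in config.yaml above.
cfg = OmegaConf.create({
    "agent": {"type": "few_shot", "max_icl_history": 1000, "cache_icl": False},
    "eval": {"icl_episodes": 1, "icl_dataset": "records"},
})

# Hypothetical overrides: more demonstrations, tighter context budget.
overrides = OmegaConf.from_dotlist(["eval.icl_episodes=5", "agent.max_icl_history=500"])
cfg = OmegaConf.merge(cfg, overrides)
print(OmegaConf.to_yaml(cfg))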
102 changes: 102 additions & 0 deletions balrog/dataset.py
@@ -0,0 +1,102 @@
import glob
import logging
import os
import random
import re
from pathlib import Path

import numpy as np


def natural_sort_key(s):
return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(s))]


def choice_excluding(lst, excluded_element):
possible_choices = [item for item in lst if item != excluded_element]
return random.choice(possible_choices)


class InContextDataset:
def __init__(self, config, env_name, original_cwd) -> None:
self.config = config
self.env_name = env_name
self.original_cwd = original_cwd

def icl_episodes(self, task):
demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task
return list(sorted(glob.glob(os.path.join(demos_dir, "**/*.npz"), recursive=True), key=natural_sort_key))

def extract_seed(self, demo_path):
# extract seed from record, example format: `20241201T225823-seed13-rew1.00-len47.npz`
seed = [part.removeprefix("seed") for part in Path(demo_path).stem.split("-") if "seed" in part]
return int(seed[0])

def demo_task(self, task):
# use a different task, so the solution to the evaluated task never ends up in the context
if self.env_name == "babaisai":
task = choice_excluding(self.config.tasks[f"{self.env_name}_tasks"], task)

return task

def demo_path(self, i, task):
icl_episodes = self.icl_episodes(task)
demo_path = icl_episodes[i % len(icl_episodes)]

# use a different seed, so the solution to the evaluated instance never ends up in the context
if self.env_name == "textworld":
from balrog.environments.textworld import global_textworld_context

textworld_context = global_textworld_context(
tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs
)
next_seed = textworld_context.count[task]
demo_seed = self.extract_seed(demo_path)
if next_seed == demo_seed:
demo_path = icl_episodes[(i + 1) % len(icl_episodes)]  # wrap around instead of indexing past the end

return demo_path

def load_episode(self, filename):
# Load the compressed NPZ file
with np.load(filename, allow_pickle=True) as data:
# Convert the archive into a plain dict of arrays
episode = {k: data[k] for k in data.files}
return episode

def load_in_context_learning_episodes(self, num_episodes, task, agent):
demo_task = self.demo_task(task)
demo_paths = [self.demo_path(i, demo_task) for i in range(len(self.icl_episodes(task)))]
random.shuffle(demo_paths)
demo_paths = demo_paths[:num_episodes]

for demo_path in demo_paths:
self.load_in_context_learning_episode(demo_path, agent)

def load_in_context_learning_episode(self, demo_path, agent):
episode = self.load_episode(demo_path)

actions = episode.pop("action").tolist()
rewards = episode.pop("reward").tolist()
terminated = episode.pop("terminated")
truncated = episode.pop("truncated")
dones = np.any([terminated, truncated], axis=0).tolist()
observations = [dict(zip(episode.keys(), values)) for values in zip(*episode.values())]

# first transition only contains observation (like env.reset())
observation, action, reward, done = observations.pop(0), actions.pop(0), rewards.pop(0), dones.pop(0)
agent.update_icl_observation(observation)

for observation, action, reward, done in zip(observations, actions, rewards, dones):
action = str(action)

agent.update_icl_action(action)
agent.update_icl_observation(observation)

if done:
break

if not done:
logging.info("icl trajectory ended without done")

agent.wrap_episode()
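For reference, load_in_context_learning_episode expects each record to be an .npz archive of aligned, step-indexed arrays: action, reward, terminated, truncated, plus one array per observation key, discovered under <icl_dataset>/<env_name>/<task>/. Below is a minimal sketch of writing a compatible record; the "text" observation layout and the file name are inferred from update_icl_observation and extract_seed above, not prescribed by the repository:

import numpy as np

# Three steps of a toy episode. "text" is an object array of dicts because
# update_icl_observation reads obs["text"]["long_term_context"]; object
# arrays are also why load_episode passes allow_pickle=True.
text = np.array(
    [{"long_term_context": f"step {t} observation"} for t in range(3)],
    dtype=object,
)
np.savez_compressed(
    "20241201T225823-seed13-rew1.00-len47.npz",  # name pattern parsed by extract_seed
    text=text,
    action=np.array(["wait", "north", "east"]),
    reward=np.array([0.0, 0.0, 1.0]),
    terminated=np.array([False, False, True]),
    truncated=np.array([False, False, False]),
)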
15 changes: 13 additions & 2 deletions balrog/evaluator.py
@@ -14,6 +14,8 @@
from omegaconf import OmegaConf
from tqdm import tqdm

from balrog.agents.few_shot import FewShotAgent
from balrog.dataset import InContextDataset
from balrog.environments import make_env
from balrog.utils import get_unique_seed

@@ -43,7 +45,7 @@ def __init__(self, config, original_cwd="", output_dir="."):
self.env_evaluators = {}
self.tasks = []
for env_name in self.env_names:
evaluator = Evaluator(env_name, config, output_dir=self.output_dir)
evaluator = Evaluator(env_name, config, original_cwd=original_cwd, output_dir=self.output_dir)
self.env_evaluators[env_name] = evaluator
for task in evaluator.tasks:
for episode_idx in range(evaluator.num_episodes):
@@ -219,7 +221,7 @@ class Evaluator:
including loading in-context learning episodes and running episodes with the agent.
"""

def __init__(self, env_name, config, output_dir="."):
def __init__(self, env_name, config, original_cwd="", output_dir="."):
"""Initialize the Evaluator.
Args:
@@ -237,6 +239,8 @@ def __init__(self, env_name, config, output_dir="."):
self.num_workers = config.eval.num_workers
self.max_steps_per_episode = config.eval.max_steps_per_episode

self.dataset = InContextDataset(self.config, self.env_name, original_cwd=original_cwd)

def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
"""Run a single evaluation episode.
@@ -284,6 +288,13 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
csv_writer = csv.writer(csv_file, escapechar="˘", quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(["Step", "Action", "Reasoning", "Observation", "Reward", "Done"])

# If the agent is a FewShotAgent, load the in-context learning episodes
if isinstance(agent, FewShotAgent):
self.dataset.load_in_context_learning_episodes(self.config.eval.icl_episodes, task, agent)

if self.config.agent.cache_icl and self.config.client.client_name == "gemini":
agent.cache_icl()

pbar_desc = f"Task: {task}, Proc: {process_num}"
pbar = tqdm(
total=max_steps_per_episode,
5 changes: 4 additions & 1 deletion balrog/prompt_builder/history.py
@@ -75,7 +75,7 @@ def reset(self):
"""Clear the event history."""
self._events.clear()

def get_prompt(self) -> List[Message]:
def get_prompt(self, icl_episodes=False) -> List[Message]:
"""Generate a list of Message objects representing the prompt.
Returns:
@@ -85,6 +85,9 @@
if self.system_prompt:
messages.append(Message(role="user", content=self.system_prompt))

if self.system_prompt and not icl_episodes:
messages.append(Message(role="user", content=self.system_prompt))

# Determine which images to include
images_needed = self.max_image_history
for event in reversed(self._events):