Skip to content

Commit

Permalink
[Arch proposal] ENVIRONMENT event source (All-Hands-AI#4584)
Browse files Browse the repository at this point in the history
Co-authored-by: Xingyao Wang <[email protected]>
  • Loading branch information
enyst and xingyaoww authored Oct 31, 2024
1 parent db4e1db commit 0687608
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 52 deletions.
8 changes: 5 additions & 3 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def report_error(self, message: str, exception: Exception | None = None):
if exception is not None and isinstance(exception, litellm.AuthenticationError):
detail = 'Please check your credentials. Is your API key correct?'
self.event_stream.add_event(
ErrorObservation(f'{message}:{detail}'), EventSource.USER
ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT
)

async def start_step_loop(self):
Expand Down Expand Up @@ -346,7 +346,8 @@ async def set_agent_state_to(self, new_state: AgentState):

self.state.agent_state = new_state
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
AgentStateChangedObservation('', self.state.agent_state),
EventSource.ENVIRONMENT,
)

if new_state == AgentState.INIT and self.state.resume_state:
Expand Down Expand Up @@ -423,7 +424,8 @@ async def _step(self) -> None:
if self._is_stuck():
# This needs to go BEFORE report_error to sync metrics
self.event_stream.add_event(
FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
FatalErrorObservation('Agent got stuck in a loop'),
EventSource.ENVIRONMENT,
)
return

Expand Down
4 changes: 2 additions & 2 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def display_event(event: Event):
if hasattr(event, 'thought'):
display_message(event.thought)
if isinstance(event, MessageAction):
if event.source != EventSource.USER:
if event.source == EventSource.AGENT:
display_message(event.content)
if isinstance(event, CmdRunAction):
display_command(event.command)
Expand Down Expand Up @@ -131,7 +131,7 @@ async def prompt_for_next_task():
next_message = input('How can I help? >> ')
if next_message == 'exit':
event_stream.add_event(
ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
)
return
action = MessageAction(content=next_message)
Expand Down
2 changes: 2 additions & 0 deletions openhands/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def serialize_model(self):


class Message(BaseModel):
# NOTE: this is not the same as EventSource
# These are the roles in the LLM's APIs
role: Literal['user', 'system', 'assistant', 'tool']
content: list[TextContent | ImageContent] = Field(default_factory=list)
cache_enabled: bool = False
Expand Down
1 change: 1 addition & 0 deletions openhands/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Names the source that an event is attributed to.
# Because the class mixes in str, each member is also a plain string that
# compares equal to its value (e.g. EventSource.AGENT == 'agent').
class EventSource(str, Enum):
# Event attributed to the agent.
AGENT = 'agent'
# Event attributed to the user.
USER = 'user'
# Event attributed to neither the agent nor the user.
# NOTE(review): presumably covers runtime/system-generated events such as
# error and agent-state-change observations -- confirm against the callers.
ENVIRONMENT = 'environment'


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions openhands/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ async def on_event(self, event: Event) -> None:
)
observation._cause = event.id # type: ignore[attr-defined]
observation.tool_call_metadata = event.tool_call_metadata

# this might be unnecessary, since the event stream should already have set the source by the time we get here
source = event.source if event.source else EventSource.AGENT
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]

Expand Down
1 change: 1 addition & 0 deletions openhands/security/invariant/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ async def confirm(self, event: Event) -> None:
new_event = action_from_dict(
{'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
)
# we should confirm only on agent actions
event_source = event.source if event.source else EventSource.AGENT
await call_sync_from_async(self.event_stream.add_event, new_event, event_source)

Expand Down
2 changes: 1 addition & 1 deletion openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def _start(
agent_configs=agent_configs,
)
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)
if self.controller:
self.controller.agent_task = self.controller.start_step_loop()
Expand Down
18 changes: 13 additions & 5 deletions openhands/server/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ async def loop_recv(self):

async def _initialize_agent(self, data: dict):
self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
)
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
)
# Extract the agent-relevant arguments from the request
args = {key: value for key, value in data.get('args', {}).items()}
Expand Down Expand Up @@ -138,12 +139,19 @@ async def on_event(self, event: Event):
return
if event.source == EventSource.AGENT:
await self.send(event_to_dict(event))
elif event.source == EventSource.USER and isinstance(
# NOTE: ipython observations are not sent here currently
elif event.source == EventSource.ENVIRONMENT and isinstance(
event, CmdOutputObservation
):
await self.send(event_to_dict(event))
# feedback from the environment to agent actions is understood as agent events by the UI
event_dict = event_to_dict(event)
event_dict['source'] = EventSource.AGENT
await self.send(event_dict)
elif isinstance(event, ErrorObservation):
await self.send(event_to_dict(event))
# send error events as agent events to the UI
event_dict = event_to_dict(event)
event_dict['source'] = EventSource.AGENT
await self.send(event_dict)

async def dispatch(self, data: dict):
action = data.get('action', '')
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ async def on_event(event: Event):
'Non fatal error here to trigger loop'
)
non_fatal_error_obs._cause = event.id
await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
await event_stream.async_add_event(
non_fatal_error_obs, EventSource.ENVIRONMENT
)

event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
runtime.event_stream = event_stream
Expand Down
Loading

0 comments on commit 0687608

Please sign in to comment.