Skip to content

Commit

Permalink
Revert "Feature: Add ability to reconnect websockets" (All-Hands-AI#4801
Browse files Browse the repository at this point in the history
)
  • Loading branch information
enyst authored Nov 7, 2024
1 parent 2b3fd94 commit 47464a9
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 60 deletions.
32 changes: 13 additions & 19 deletions frontend/src/context/socket.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ import React from "react";
import { Data } from "ws";
import EventLogger from "#/utils/event-logger";

const RECONNECT_RETRIES = 5;

interface WebSocketClientOptions {
token: string | null;
onOpen?: (event: Event, isNewSession: boolean) => void;
onOpen?: (event: Event) => void;
onMessage?: (event: MessageEvent<Data>) => void;
onError?: (event: Event) => void;
onClose?: (event: Event) => void;
Expand All @@ -16,8 +14,8 @@ interface WebSocketContextType {
send: (data: string | ArrayBufferLike | Blob | ArrayBufferView) => void;
start: (options?: WebSocketClientOptions) => void;
stop: () => void;
setRuntimeIsInitialized: (runtimeIsInitialized: boolean) => void;
runtimeIsInitialized: boolean;
setRuntimeIsInitialized: () => void;
runtimeActive: boolean;
isConnected: boolean;
events: Record<string, unknown>[];
}
Expand All @@ -32,11 +30,14 @@ interface SocketProviderProps {

function SocketProvider({ children }: SocketProviderProps) {
const wsRef = React.useRef<WebSocket | null>(null);
const wsReconnectRetries = React.useRef<number>(RECONNECT_RETRIES);
const [isConnected, setIsConnected] = React.useState(false);
const [runtimeIsInitialized, setRuntimeIsInitialized] = React.useState(false);
const [runtimeActive, setRuntimeActive] = React.useState(false);
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);

const setRuntimeIsInitialized = () => {
setRuntimeActive(true);
};

const start = React.useCallback((options?: WebSocketClientOptions): void => {
if (wsRef.current) {
EventLogger.warning(
Expand All @@ -58,9 +59,7 @@ function SocketProvider({ children }: SocketProviderProps) {

ws.addEventListener("open", (event) => {
setIsConnected(true);
const isNewSession = sessionToken === "NO_JWT";
wsReconnectRetries.current = RECONNECT_RETRIES;
options?.onOpen?.(event, isNewSession);
options?.onOpen?.(event);
});

ws.addEventListener("message", (event) => {
Expand All @@ -77,22 +76,17 @@ function SocketProvider({ children }: SocketProviderProps) {

ws.addEventListener("close", (event) => {
EventLogger.event(event, "SOCKET CLOSE");

setIsConnected(false);
setRuntimeIsInitialized(false);
setRuntimeActive(false);
wsRef.current = null;
options?.onClose?.(event);
if (wsReconnectRetries.current) {
wsReconnectRetries.current -= 1;
const token = localStorage.getItem("token");
setTimeout(() => start({ ...(options || {}), token }), 1);
}
});

wsRef.current = ws;
}, []);

const stop = React.useCallback((): void => {
wsReconnectRetries.current = 0;
if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
Expand All @@ -117,7 +111,7 @@ function SocketProvider({ children }: SocketProviderProps) {
start,
stop,
setRuntimeIsInitialized,
runtimeIsInitialized,
runtimeActive,
isConnected,
events,
}),
Expand All @@ -126,7 +120,7 @@ function SocketProvider({ children }: SocketProviderProps) {
start,
stop,
setRuntimeIsInitialized,
runtimeIsInitialized,
runtimeActive,
isConnected,
events,
],
Expand Down
48 changes: 18 additions & 30 deletions frontend/src/routes/_oh.app.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ function App() {
const { files, importedProjectZip } = useSelector(
(state: RootState) => state.initalQuery,
);
const { start, send, setRuntimeIsInitialized, runtimeIsInitialized } =
useSocket();
const { start, send, setRuntimeIsInitialized, runtimeActive } = useSocket();
const { settings, token, ghToken, repo, q, lastCommit } =
useLoaderData<typeof clientLoader>();
const fetcher = useFetcher();
Expand Down Expand Up @@ -162,32 +161,21 @@ function App() {
);
};

const doSendInitialQuery = React.useRef<boolean>(true);

const sendInitialQuery = (query: string, base64Files: string[]) => {
const timestamp = new Date().toISOString();
send(createChatMessage(query, base64Files, timestamp));
};

const handleOpen = React.useCallback(
(event: Event, isNewSession: boolean) => {
if (!isNewSession) {
dispatch(clearMessages());
dispatch(clearTerminal());
dispatch(clearJupyter());
}
doSendInitialQuery.current = isNewSession;
const initEvent = {
action: ActionType.INIT,
args: settings,
};
send(JSON.stringify(initEvent));

// display query in UI, but don't send it to the server
if (q && isNewSession) addIntialQueryToChat(q, files);
},
[settings],
);
const handleOpen = React.useCallback(() => {
const initEvent = {
action: ActionType.INIT,
args: settings,
};
send(JSON.stringify(initEvent));

// display query in UI, but don't send it to the server
if (q) addIntialQueryToChat(q, files);
}, [settings]);

const handleMessage = React.useCallback(
(message: MessageEvent<WebSocket.Data>) => {
Expand Down Expand Up @@ -230,7 +218,7 @@ function App() {
isAgentStateChange(parsed) &&
parsed.extras.agent_state === AgentState.INIT
) {
setRuntimeIsInitialized(true);
setRuntimeIsInitialized();

// handle new session
if (!token) {
Expand All @@ -245,7 +233,7 @@ function App() {
additionalInfo = `Files have been uploaded. Please check the /workspace for files.`;
}

if (q && doSendInitialQuery.current) {
if (q) {
if (additionalInfo) {
sendInitialQuery(`${q}\n\n[${additionalInfo}]`, files);
} else {
Expand Down Expand Up @@ -277,15 +265,15 @@ function App() {
});

React.useEffect(() => {
if (runtimeIsInitialized && userId && ghToken) {
if (runtimeActive && userId && ghToken) {
// Export if the user valid, this could happen mid-session so it is handled here
send(getGitHubTokenCommand(ghToken));
}
}, [userId, ghToken, runtimeIsInitialized]);
}, [userId, ghToken, runtimeActive]);

React.useEffect(() => {
(async () => {
if (runtimeIsInitialized && importedProjectZip) {
if (runtimeActive && importedProjectZip) {
// upload files action
try {
const blob = base64ToBlob(importedProjectZip);
Expand All @@ -299,7 +287,7 @@ function App() {
}
}
})();
}, [runtimeIsInitialized, importedProjectZip]);
}, [runtimeActive, importedProjectZip]);

const {
isOpen: securityModalIsOpen,
Expand All @@ -315,7 +303,7 @@ function App() {
className={cn(
"w-2 h-2 rounded-full border",
"absolute left-3 top-3",
runtimeIsInitialized
runtimeActive
? "bg-green-800 border-green-500"
: "bg-red-800 border-red-500",
)}
Expand Down
1 change: 0 additions & 1 deletion openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def __init__(
max_iterations=max_iterations,
confirmation_mode=confirmation_mode,
)

self.max_budget_per_task = max_budget_per_task
self.agent_to_llm_config = agent_to_llm_config if agent_to_llm_config else {}
self.agent_configs = agent_configs if agent_configs else {}
Expand Down
3 changes: 1 addition & 2 deletions openhands/server/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,11 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
return
logger.info(f'Existing session: {sid}')
else:
sid = str(uuid.uuid4())
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')

logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': jwt_token, 'status': 'ok'})

Expand Down
6 changes: 3 additions & 3 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ async def _start(
config: AppConfig,
agent: Agent,
max_iterations: int,
max_budget_per_task: float | None,
agent_to_llm_config: dict[str, LLMConfig] | None,
agent_configs: dict[str, AgentConfig] | None,
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
):
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
Expand Down
10 changes: 5 additions & 5 deletions openhands/server/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Session:
sid: str
websocket: WebSocket | None
last_active_ts: int = 0
is_alive: bool = False
is_alive: bool = True
agent_session: AgentSession
loop: asyncio.AbstractEventLoop

Expand Down Expand Up @@ -109,7 +109,6 @@ async def _initialize_agent(self, data: dict):

# Create the agent session
try:
self.is_alive = True
await self.agent_session.start(
runtime_name=self.config.runtime,
config=self.config,
Expand Down Expand Up @@ -139,7 +138,9 @@ async def on_event(self, event: Event):
return
if event.source == EventSource.AGENT:
await self.send(event_to_dict(event))
elif event.source == EventSource.USER and isinstance(event, CmdOutputObservation):
elif event.source == EventSource.USER and isinstance(
event, CmdOutputObservation
):
await self.send(event_to_dict(event))
# NOTE: ipython observations are not sent here currently
elif event.source == EventSource.ENVIRONMENT and isinstance(
Expand All @@ -158,8 +159,7 @@ async def on_event(self, event: Event):
async def dispatch(self, data: dict):
action = data.get('action', '')
if action == ActionType.INIT:
if not self.is_alive:
await self._initialize_agent(data)
await self._initialize_agent(data)
return
event = event_from_dict(data.copy())
# This checks if the model supports images
Expand Down

0 comments on commit 47464a9

Please sign in to comment.