forked from microsoft/autogen
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use jupyer-kernel-gateway for ipython executor (microsoft#1748)
* checkpoint async based * Implement jupyter client and use jupyer gateway * update deps * address comments * add missing parenthesis * Update build.yml * CI fixes * change requirement name * debug * print stderr * dont seek * show token * mitigaton for windows bug * use hex token to avoid - in token * formatting * put back in place original while the windows bug exists * lint * Update autogen/coding/jupyter_code_executor.py * Update jupyter_code_executor.py * Update test_embedded_ipython_code_executor.py * Update setup.py * Update build.yml * fix nameerror --------- Co-authored-by: Eric Zhu <[email protected]>
- Loading branch information
1 parent
ac15996
commit fbc2f6e
Showing
11 changed files
with
673 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .base import JupyterConnectable, JupyterConnectionInfo | ||
from .jupyter_client import JupyterClient | ||
from .local_jupyter_server import LocalJupyterServer | ||
|
||
__all__ = ["JupyterConnectable", "JupyterConnectionInfo", "JupyterClient", "LocalJupyterServer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from dataclasses import dataclass | ||
from typing import Optional, Protocol, runtime_checkable | ||
|
||
|
||
@dataclass | ||
class JupyterConnectionInfo: | ||
"""(Experimental)""" | ||
|
||
host: str | ||
use_https: bool | ||
port: int | ||
token: Optional[str] | ||
|
||
|
||
@runtime_checkable | ||
class JupyterConnectable(Protocol): | ||
"""(Experimental)""" | ||
|
||
@property | ||
def connection_info(self) -> JupyterConnectionInfo: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from types import TracebackType | ||
from typing import Any, Dict, List, Optional, cast | ||
import sys | ||
|
||
if sys.version_info >= (3, 11): | ||
from typing import Self | ||
else: | ||
from typing_extensions import Self | ||
|
||
import json | ||
import uuid | ||
import datetime | ||
import requests | ||
|
||
import websocket | ||
from websocket import WebSocket | ||
|
||
from .base import JupyterConnectionInfo | ||
|
||
|
||
class JupyterClient: | ||
"""(Experimental) A client for communicating with a Jupyter gateway server.""" | ||
|
||
def __init__(self, connection_info: JupyterConnectionInfo): | ||
self._connection_info = connection_info | ||
|
||
def _get_headers(self) -> Dict[str, str]: | ||
if self._connection_info.token is None: | ||
return {} | ||
return {"Authorization": f"token {self._connection_info.token}"} | ||
|
||
def _get_api_base_url(self) -> str: | ||
protocol = "https" if self._connection_info.use_https else "http" | ||
return f"{protocol}://{self._connection_info.host}:{self._connection_info.port}" | ||
|
||
def _get_ws_base_url(self) -> str: | ||
return f"ws://{self._connection_info.host}:{self._connection_info.port}" | ||
|
||
def list_kernel_specs(self) -> Dict[str, Dict[str, str]]: | ||
response = requests.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()) | ||
return cast(Dict[str, Dict[str, str]], response.json()) | ||
|
||
def list_kernels(self) -> List[Dict[str, str]]: | ||
response = requests.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()) | ||
return cast(List[Dict[str, str]], response.json()) | ||
|
||
def start_kernel(self, kernel_spec_name: str) -> str: | ||
"""Start a new kernel. | ||
Args: | ||
kernel_spec_name (str): Name of the kernel spec to start | ||
Returns: | ||
str: ID of the started kernel | ||
""" | ||
|
||
response = requests.post( | ||
f"{self._get_api_base_url()}/api/kernels", | ||
headers=self._get_headers(), | ||
json={"name": kernel_spec_name}, | ||
) | ||
return cast(str, response.json()["id"]) | ||
|
||
def restart_kernel(self, kernel_id: str) -> None: | ||
response = requests.post( | ||
f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers() | ||
) | ||
response.raise_for_status() | ||
|
||
def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient: | ||
ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels" | ||
ws = websocket.create_connection(ws_url, header=self._get_headers()) | ||
return JupyterKernelClient(ws) | ||
|
||
|
||
class JupyterKernelClient: | ||
"""(Experimental) A client for communicating with a Jupyter kernel.""" | ||
|
||
@dataclass | ||
class ExecutionResult: | ||
@dataclass | ||
class DataItem: | ||
mime_type: str | ||
data: str | ||
|
||
is_ok: bool | ||
output: str | ||
data_items: List[DataItem] | ||
|
||
def __init__(self, websocket: WebSocket): | ||
self._session_id: str = uuid.uuid4().hex | ||
self._websocket: WebSocket = websocket | ||
|
||
def __enter__(self) -> Self: | ||
return self | ||
|
||
def __exit__( | ||
self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] | ||
) -> None: | ||
self._websocket.close() | ||
|
||
def _send_message(self, *, content: Dict[str, Any], channel: str, message_type: str) -> str: | ||
timestamp = datetime.datetime.now().isoformat() | ||
message_id = uuid.uuid4().hex | ||
message = { | ||
"header": { | ||
"username": "autogen", | ||
"version": "5.0", | ||
"session": self._session_id, | ||
"msg_id": message_id, | ||
"msg_type": message_type, | ||
"date": timestamp, | ||
}, | ||
"parent_header": {}, | ||
"channel": channel, | ||
"content": content, | ||
"metadata": {}, | ||
"buffers": {}, | ||
} | ||
self._websocket.send_text(json.dumps(message)) | ||
return message_id | ||
|
||
def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[Dict[str, Any]]: | ||
self._websocket.settimeout(timeout_seconds) | ||
try: | ||
data = self._websocket.recv() | ||
if isinstance(data, bytes): | ||
data = data.decode("utf-8") | ||
return cast(Dict[str, Any], json.loads(data)) | ||
except websocket.WebSocketTimeoutException: | ||
return None | ||
|
||
def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool: | ||
message_id = self._send_message(content={}, channel="shell", message_type="kernel_info_request") | ||
while True: | ||
message = self._receive_message(timeout_seconds) | ||
# This means we timed out with no new messages. | ||
if message is None: | ||
return False | ||
if ( | ||
message.get("parent_header", {}).get("msg_id") == message_id | ||
and message["msg_type"] == "kernel_info_reply" | ||
): | ||
return True | ||
|
||
def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult: | ||
message_id = self._send_message( | ||
content={ | ||
"code": code, | ||
"silent": False, | ||
"store_history": True, | ||
"user_expressions": {}, | ||
"allow_stdin": False, | ||
"stop_on_error": True, | ||
}, | ||
channel="shell", | ||
message_type="execute_request", | ||
) | ||
|
||
text_output = [] | ||
data_output = [] | ||
while True: | ||
message = self._receive_message(timeout_seconds) | ||
if message is None: | ||
return JupyterKernelClient.ExecutionResult( | ||
is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[] | ||
) | ||
|
||
# Ignore messages that are not for this execution. | ||
if message.get("parent_header", {}).get("msg_id") != message_id: | ||
continue | ||
|
||
msg_type = message["msg_type"] | ||
content = message["content"] | ||
if msg_type in ["execute_result", "display_data"]: | ||
for data_type, data in content["data"].items(): | ||
if data_type == "text/plain": | ||
text_output.append(data) | ||
elif data_type.startswith("image/") or data_type == "text/html": | ||
data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data)) | ||
else: | ||
text_output.append(json.dumps(data)) | ||
elif msg_type == "stream": | ||
text_output.append(content["text"]) | ||
elif msg_type == "error": | ||
# Output is an error. | ||
return JupyterKernelClient.ExecutionResult( | ||
is_ok=False, | ||
output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}", | ||
data_items=[], | ||
) | ||
if msg_type == "status" and content["execution_state"] == "idle": | ||
break | ||
|
||
return JupyterKernelClient.ExecutionResult( | ||
is_ok=True, output="\n".join([str(output) for output in text_output]), data_items=data_output | ||
) |
Oops, something went wrong.