Skip to content

Commit

Permalink
fix bug
Browse files: browse the repository at this point in the history
  • Loading branch information
liqul committed Feb 23, 2024
1 parent 40c8bef commit 59eefb2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 34 deletions.
9 changes: 6 additions & 3 deletions ces_container/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ RUN pip install --no-cache-dir -r requirements.txt
# Copy the project code
COPY taskweaver/ces /app/taskweaver/ces
COPY taskweaver/plugin /app/taskweaver/plugin
COPY taskweaver/module /app/taskweaver/module
COPY taskweaver/__init__.py /app/taskweaver/__init__.py
# Add the taskweaver directories to the PYTHONPATH
ENV PYTHONPATH="${PYTHONPATH}:/app/taskweaver"

CMD ["python", "/app/taskweaver/ces/docker_entry.py"]
RUN mv /app/taskweaver/ces/docker_entry.py /app/docker_entry.py

ENV PYTHONPATH "${PYTHONPATH}:/app"

CMD ["python", "docker_entry.py"]


6 changes: 4 additions & 2 deletions taskweaver/ces/docker_entry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import sys
import time

sys.path.append("/app")

from taskweaver.ces import Environment, EnvMode

# Flag to control the main loop
Expand All @@ -12,7 +15,7 @@
)
env_dir = os.getenv(
"TASKWEAVER_ENV_DIR",
os.path.realpath(os.getcwd()),
"/app",
)
session_id = os.getenv(
"TASKWEAVER_SESSION_ID",
Expand All @@ -29,7 +32,6 @@
"kernel_id",
)


if __name__ == "__main__":
env = Environment(env_id, env_dir, env_mode=EnvMode.InsideContainer)
env.start_session(session_id, port_start=port_start, kernel_id=kernel_id)
Expand Down
79 changes: 50 additions & 29 deletions taskweaver/ces/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import sys
import time
from ast import literal_eval
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Union
Expand Down Expand Up @@ -171,16 +172,6 @@ def _get_connection_file(self, session_id: str, kernel_id: str) -> str:
f"conn-{session_id}-{kernel_id}.json",
)

def _prepare_session(self, cwd, session_dir, session_id):
    """Set up the on-disk layout and identifiers for a new kernel session.

    Looks up the session via ``self._get_session``, creates the ``ces``
    sub-directory under the session directory, generates a fresh kernel id,
    and makes sure a working directory exists.

    Args:
        cwd: Working directory to use; when ``None`` a ``cwd`` directory
            under the session directory is used instead.
        session_dir: Forwarded to ``self._get_session`` as ``session_dir``.
        session_id: Identifier of the session being prepared.

    Returns:
        A 5-tuple ``(ces_session_dir, connection_file, cwd, kernel_id,
        session)``.
    """
    sess = self._get_session(session_id, session_dir=session_dir)
    ces_dir = os.path.join(sess.session_dir, "ces")
    fresh_kernel_id = get_id(prefix="knl")
    os.makedirs(ces_dir, exist_ok=True)
    conn_file = self._get_connection_file(session_id, fresh_kernel_id)

    # Fall back to a per-session working directory when none was supplied.
    if cwd is None:
        cwd = os.path.join(sess.session_dir, "cwd")
    os.makedirs(cwd, exist_ok=True)

    return ces_dir, conn_file, cwd, fresh_kernel_id, sess

def start_session(
self,
session_id: str,
Expand All @@ -190,11 +181,14 @@ def start_session(
port_start: Optional[int] = None,
) -> None:
if self.mode == EnvMode.SubProcess:
ces_session_dir, connection_file, cwd, new_kernel_id, session = self._prepare_session(
cwd,
session_dir,
session_id,
)
session = self._get_session(session_id, session_dir=session_dir)
ces_session_dir = os.path.join(session.session_dir, "ces")
new_kernel_id = get_id(prefix="knl")
os.makedirs(ces_session_dir, exist_ok=True)
connection_file = self._get_connection_file(session_id, new_kernel_id)
cwd = cwd if cwd is not None else os.path.join(session.session_dir, "cwd")
os.makedirs(cwd, exist_ok=True)

# set python home from current python environment
python_home = os.path.sep.join(sys.executable.split(os.path.sep)[:-2])
python_path = os.pathsep.join(
Expand Down Expand Up @@ -235,11 +229,11 @@ def start_session(
session.kernel_status = "ready"

elif self.mode == EnvMode.OutsideContainer:
ces_session_dir, connection_file, cwd, new_kernel_id, session = self._prepare_session(
cwd,
session_dir,
session_id,
)
session = self._get_session(session_id, session_dir=session_dir)
ces_session_dir = os.path.join(session.session_dir, "ces")
new_kernel_id = get_id(prefix="knl")
os.makedirs(ces_session_dir, exist_ok=True)
connection_file = self._get_connection_file(session_id, new_kernel_id)

kernel_env = {
"TASKWEAVER_ENV_ID": self.id,
Expand All @@ -264,6 +258,17 @@ def start_session(
},
)

tick = 0
while tick < 10:
container.reload()
if container.status == "running" and os.path.isfile(connection_file):
print("Container is running and connection file is ready.")
break
time.sleep(1) # wait for 1 second before checking again
tick += 1
if tick == 10:
raise Exception("Container is not ready after 10 seconds")

self.session_container_dict[session_id] = container.id
self.session_container_port_dict[session_id] = self.port_start
self.port_start += 5
Expand All @@ -274,15 +279,29 @@ def start_session(
session.kernel_status = "ready"
elif self.mode == EnvMode.InsideContainer:
assert port_start is not None, "Port start must be provided when inside container."
kernel_env = {
"JUPYTER_SHELL_PORT": str(port_start),
"JUPYTER_IOPUB_PORT": str(port_start + 1),
"JUPYTER_STDIN_PORT": str(port_start + 2),
"JUPYTER_HB_PORT": str(port_start + 3),
"JUPYTER_CONTROL_PORT": str(port_start + 4),
"JUPYTER_KERNEL_IP": "0.0.0.0",
"JUPYTER_MANUAL_PORTS": "True",
}
session = self._get_session(session_id, session_dir=session_dir)
ces_session_dir = os.path.join(session.session_dir, "ces")
connection_file = self._get_connection_file(session_id, kernel_id)
cwd = cwd if cwd is not None else os.path.join(session.session_dir, "cwd")

kernel_env = os.environ.copy()
kernel_env.update(
{
"JUPYTER_SHELL_PORT": str(port_start),
"JUPYTER_IOPUB_PORT": str(port_start + 1),
"JUPYTER_STDIN_PORT": str(port_start + 2),
"JUPYTER_HB_PORT": str(port_start + 3),
"JUPYTER_CONTROL_PORT": str(port_start + 4),
"JUPYTER_KERNEL_IP": "0.0.0.0",
"JUPYTER_MANUAL_PORTS": "True",
"CONNECTION_FILE": connection_file,
"TASKWEAVER_LOGGING_FILE_PATH": os.path.join(
ces_session_dir,
"kernel_logging.log",
),
"PATH": os.environ["PATH"],
},
)

kernel_id = self.multi_kernel_manager.start_kernel(
kernel_id=kernel_id,
Expand Down Expand Up @@ -467,6 +486,8 @@ def _get_client(
print(connection_file)
client = BlockingKernelClient(connection_file=connection_file)
client.load_connection_file()
# overwrite the ip to localhost
client.ip = "127.0.0.1"
return client

def _execute_code_on_kernel(
Expand Down

0 comments on commit 59eefb2

Please sign in to comment.