forked from iterative/dvc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
git: use asyncssh as dulwich SSH vendor (iterative#6797)
* git: support client kwargs for auth in remote methods * tests: use docker for git/ssh server * move dulwich backend into subdir * add base sync<->async support * dulwich: add initial asyncssh SSH vendor * dulwich: use asyncssh as the default SSH vendor * add test for git over ssh * test read (fetch) over ssh * ci: re-enable ssh tests temporary workaround for iterative#6788, iterative#6789 * move async stubs into dvc.scm.asyn * tests: use sshfs and not dvc.fs.ssh * allow explicit 0-length read in asyncssh vendor
- Loading branch information
Showing
12 changed files
with
294 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,7 +109,7 @@ jobs: | |
run: >- | ||
python -m tests -n=4 | ||
--cov-report=xml --cov-report=term | ||
${{ env.extra_test_args }} | ||
--enable-ssh ${{ env.extra_test_args }} | ||
- name: upload coverage report | ||
uses: codecov/[email protected] | ||
with: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
"""DVC re-implementation of fsspec's dedicated async event loop.""" | ||
import asyncio | ||
import os | ||
import threading | ||
from typing import List, Optional | ||
|
||
from fsspec.asyn import ( # noqa: F401, pylint:disable=unused-import | ||
_selector_policy, | ||
sync, | ||
sync_wrapper, | ||
) | ||
|
||
# dedicated async IO thread | ||
iothread: List[Optional[threading.Thread]] = [None] | ||
# global DVC event loop | ||
default_loop: List[Optional[asyncio.AbstractEventLoop]] = [None] | ||
lock = threading.Lock() | ||
|
||
|
||
def get_loop() -> asyncio.AbstractEventLoop: | ||
"""Create or return the global DVC event loop.""" | ||
if default_loop[0] is None: | ||
with lock: | ||
if default_loop[0] is None: | ||
with _selector_policy(): | ||
default_loop[0] = asyncio.new_event_loop() | ||
loop = default_loop[0] | ||
th = threading.Thread( | ||
target=loop.run_forever, # type: ignore[attr-defined] | ||
name="dvcIO", | ||
) | ||
th.daemon = True | ||
th.start() | ||
iothread[0] = th | ||
assert default_loop[0] is not None | ||
return default_loop[0] | ||
|
||
|
||
class BaseAsyncObject: | ||
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): | ||
self._loop: asyncio.AbstractEventLoop = loop or get_loop() | ||
self._pid = os.getpid() | ||
|
||
@property | ||
def loop(self): | ||
# AsyncMixin is not fork-safe | ||
assert self._pid == os.getpid() | ||
return self._loop |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
"""asyncssh SSH vendor for Dulwich.""" | ||
from typing import List, Optional | ||
|
||
from dulwich.client import SSHVendor | ||
|
||
from dvc.scm.asyn import BaseAsyncObject, sync_wrapper | ||
|
||
|
||
class _StderrWrapper: | ||
def __init__(self, stderr): | ||
self.stderr = stderr | ||
|
||
async def _readlines(self): | ||
lines = [] | ||
while True: | ||
line = await self.stderr.readline() | ||
if not line: | ||
break | ||
lines.append(line) | ||
return lines | ||
|
||
readlines = sync_wrapper(_readlines) | ||
|
||
|
||
class AsyncSSHWrapper(BaseAsyncObject): | ||
def __init__(self, conn, proc, **kwargs): | ||
super().__init__(**kwargs) | ||
self.conn = conn | ||
self.proc = proc | ||
self.stderr = _StderrWrapper(proc.stderr) | ||
|
||
def can_read(self) -> bool: | ||
# pylint:disable=protected-access | ||
return self.proc.stdout._session._recv_buf_len > 0 | ||
|
||
async def _read(self, n: Optional[int] = None) -> bytes: | ||
if self.proc.stdout.at_eof(): | ||
return b"" | ||
|
||
return await self.proc.stdout.read(n=n if n is not None else -1) | ||
|
||
read = sync_wrapper(_read) | ||
|
||
def write(self, data: bytes): | ||
self.proc.stdin.write(data) | ||
|
||
def close(self): | ||
self.conn.close() | ||
|
||
|
||
class AsyncSSHVendor(BaseAsyncObject, SSHVendor): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
async def _run_command( | ||
self, | ||
host: str, | ||
command: List[str], | ||
username: Optional[str] = None, | ||
port: Optional[int] = None, | ||
password: Optional[str] = None, | ||
key_filename: Optional[str] = None, | ||
**kwargs, | ||
): | ||
"""Connect to an SSH server. | ||
Run a command remotely and return a file-like object for interaction | ||
with the remote command. | ||
Args: | ||
host: Host name | ||
command: Command to run (as argv array) | ||
username: Optional ame of user to log in as | ||
port: Optional SSH port to use | ||
password: Optional ssh password for login or private key | ||
key_filename: Optional path to private keyfile | ||
""" | ||
import asyncssh | ||
|
||
conn = await asyncssh.connect( | ||
host, | ||
port=port, | ||
username=username, | ||
password=password, | ||
client_keys=[key_filename] if key_filename else [], | ||
known_hosts=None, | ||
encoding=None, | ||
) | ||
proc = await conn.create_process(command, encoding=None) | ||
return AsyncSSHWrapper(conn, proc) | ||
|
||
run_command = sync_wrapper(_run_command) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.