Skip to content

Commit

Permalink
git: use asyncssh as dulwich SSH vendor (iterative#6797)
Browse files Browse the repository at this point in the history
* git: support client kwargs for auth in remote methods

* tests: use docker for git/ssh server

* move dulwich backend into subdir

* add base sync<->async support

* dulwich: add initial asyncssh SSH vendor

* dulwich: use asyncssh as the default SSH vendor

* add test for git over ssh

* test read (fetch) over ssh

* ci: re-enable ssh tests

temporary workaround for iterative#6788, iterative#6789

* move async stubs into dvc.scm.asyn

* tests: use sshfs and not dvc.fs.ssh

* allow explicit 0-length read in asyncssh vendor
  • Loading branch information
pmrowla authored Oct 21, 2021
1 parent b652670 commit 73f2fa8
Show file tree
Hide file tree
Showing 12 changed files with 294 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ jobs:
run: >-
python -m tests -n=4
--cov-report=xml --cov-report=term
${{ env.extra_test_args }}
--enable-ssh ${{ env.extra_test_args }}
- name: upload coverage report
uses: codecov/[email protected]
with:
Expand Down
48 changes: 48 additions & 0 deletions dvc/scm/asyn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""DVC re-implementation of fsspec's dedicated async event loop."""
import asyncio
import os
import threading
from typing import List, Optional

from fsspec.asyn import ( # noqa: F401, pylint:disable=unused-import
_selector_policy,
sync,
sync_wrapper,
)

# dedicated async IO thread
iothread: List[Optional[threading.Thread]] = [None]
# global DVC event loop
default_loop: List[Optional[asyncio.AbstractEventLoop]] = [None]
lock = threading.Lock()


def get_loop() -> asyncio.AbstractEventLoop:
"""Create or return the global DVC event loop."""
if default_loop[0] is None:
with lock:
if default_loop[0] is None:
with _selector_policy():
default_loop[0] = asyncio.new_event_loop()
loop = default_loop[0]
th = threading.Thread(
target=loop.run_forever, # type: ignore[attr-defined]
name="dvcIO",
)
th.daemon = True
th.start()
iothread[0] = th
assert default_loop[0] is not None
return default_loop[0]


class BaseAsyncObject:
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
self._loop: asyncio.AbstractEventLoop = loop or get_loop()
self._pid = os.getpid()

@property
def loop(self):
# AsyncMixin is not fork-safe
assert self._pid == os.getpid()
return self._loop
6 changes: 4 additions & 2 deletions dvc/scm/git/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def iter_refs(self, base: Optional[str] = None):
"""

@abstractmethod
def iter_remote_refs(self, url: str, base: Optional[str] = None):
def iter_remote_refs(self, url: str, base: Optional[str] = None, **kwargs):
"""Iterate over all refs in the specified remote Git repo.
If base is specified, only refs which begin with base will be yielded.
Expand All @@ -197,6 +197,7 @@ def push_refspec(
dest: str,
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
"""Push refspec to a remote Git repo.
Expand All @@ -222,6 +223,7 @@ def fetch_refspecs(
refspecs: Iterable[str],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
"""Fetch refspecs from a remote Git repo.
Expand Down Expand Up @@ -349,5 +351,5 @@ def merge(
"""

@abstractmethod
def validate_git_remote(self, url: str):
def validate_git_remote(self, url: str, **kwargs):
"""Verify that url is a valid git URL or remote name."""
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from dvc.scm.base import GitAuthError, InvalidRemoteSCMRepo, SCMError
from dvc.utils import relpath

from ..objects import GitObject
from .base import BaseGitBackend
from ...objects import GitObject
from ..base import BaseGitBackend

if TYPE_CHECKING:
from dvc.types import StrPath

from ..objects import GitCommit
from ...objects import GitCommit

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,6 +80,13 @@ def sha(self) -> str:
class DulwichBackend(BaseGitBackend): # pylint:disable=abstract-method
"""Dulwich Git backend."""

from dulwich import client

from .asyncssh_vendor import AsyncSSHVendor

# monkeypatch dulwich client's default SSH vendor to use asyncssh
client.get_ssh_vendor = AsyncSSHVendor

# Dulwich progress will return messages equivalent to git CLI,
# our pbars should just display the messages as formatted by dulwich
BAR_FMT_NOTOTAL = "{desc}{bar:b}|{postfix[info]} [{elapsed}]"
Expand Down Expand Up @@ -354,14 +361,14 @@ def iter_refs(self, base: Optional[str] = None):
else:
yield os.fsdecode(key)

def iter_remote_refs(self, url: str, base: Optional[str] = None):
def iter_remote_refs(self, url: str, base: Optional[str] = None, **kwargs):
from dulwich.client import HTTPUnauthorized, get_transport_and_path
from dulwich.errors import NotGitRepository
from dulwich.porcelain import get_remote_repo

try:
_remote, location = get_remote_repo(self.repo, url)
client, path = get_transport_and_path(location)
client, path = get_transport_and_path(location, **kwargs)
except Exception as exc:
raise InvalidRemoteSCMRepo(url) from exc

Expand Down Expand Up @@ -389,6 +396,7 @@ def push_refspec(
dest: str,
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
from dulwich.client import HTTPUnauthorized, get_transport_and_path
from dulwich.errors import NotGitRepository, SendPackError
Expand All @@ -402,7 +410,7 @@ def push_refspec(

try:
_remote, location = get_remote_repo(self.repo, url)
client, path = get_transport_and_path(location)
client, path = get_transport_and_path(location, **kwargs)
except Exception as exc:
raise SCMError(
f"'{url}' is not a valid Git remote or URL"
Expand Down Expand Up @@ -476,6 +484,7 @@ def fetch_refspecs(
refspecs: Iterable[str],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
from dulwich.client import get_transport_and_path
from dulwich.objectspec import parse_reftuples
Expand Down Expand Up @@ -504,7 +513,7 @@ def determine_wants(remote_refs):

try:
_remote, location = get_remote_repo(self.repo, url)
client, path = get_transport_and_path(location)
client, path = get_transport_and_path(location, **kwargs)
except Exception as exc:
raise SCMError(
f"'{url}' is not a valid Git remote or URL"
Expand Down Expand Up @@ -659,13 +668,13 @@ def merge(
) -> Optional[str]:
raise NotImplementedError

def validate_git_remote(self, url: str):
def validate_git_remote(self, url: str, **kwargs):
from dulwich.client import LocalGitClient, get_transport_and_path
from dulwich.porcelain import get_remote_repo

try:
_, location = get_remote_repo(self.repo, url)
client, path = get_transport_and_path(location)
client, path = get_transport_and_path(location, **kwargs)
except Exception as exc:
raise InvalidRemoteSCMRepo(url) from exc
if isinstance(client, LocalGitClient) and not os.path.exists(
Expand Down
92 changes: 92 additions & 0 deletions dvc/scm/git/backend/dulwich/asyncssh_vendor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""asyncssh SSH vendor for Dulwich."""
from typing import List, Optional

from dulwich.client import SSHVendor

from dvc.scm.asyn import BaseAsyncObject, sync_wrapper


class _StderrWrapper:
def __init__(self, stderr):
self.stderr = stderr

async def _readlines(self):
lines = []
while True:
line = await self.stderr.readline()
if not line:
break
lines.append(line)
return lines

readlines = sync_wrapper(_readlines)


class AsyncSSHWrapper(BaseAsyncObject):
def __init__(self, conn, proc, **kwargs):
super().__init__(**kwargs)
self.conn = conn
self.proc = proc
self.stderr = _StderrWrapper(proc.stderr)

def can_read(self) -> bool:
# pylint:disable=protected-access
return self.proc.stdout._session._recv_buf_len > 0

async def _read(self, n: Optional[int] = None) -> bytes:
if self.proc.stdout.at_eof():
return b""

return await self.proc.stdout.read(n=n if n is not None else -1)

read = sync_wrapper(_read)

def write(self, data: bytes):
self.proc.stdin.write(data)

def close(self):
self.conn.close()


class AsyncSSHVendor(BaseAsyncObject, SSHVendor):
def __init__(self, **kwargs):
super().__init__(**kwargs)

async def _run_command(
self,
host: str,
command: List[str],
username: Optional[str] = None,
port: Optional[int] = None,
password: Optional[str] = None,
key_filename: Optional[str] = None,
**kwargs,
):
"""Connect to an SSH server.
Run a command remotely and return a file-like object for interaction
with the remote command.
Args:
host: Host name
command: Command to run (as argv array)
username: Optional ame of user to log in as
port: Optional SSH port to use
password: Optional ssh password for login or private key
key_filename: Optional path to private keyfile
"""
import asyncssh

conn = await asyncssh.connect(
host,
port=port,
username=username,
password=password,
client_keys=[key_filename] if key_filename else [],
known_hosts=None,
encoding=None,
)
proc = await conn.create_process(command, encoding=None)
return AsyncSSHWrapper(conn, proc)

run_command = sync_wrapper(_run_command)
6 changes: 4 additions & 2 deletions dvc/scm/git/backend/gitpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def iter_refs(self, base: Optional[str] = None):
for ref in Reference.iter_items(self.repo, common_path=base):
yield ref.path

def iter_remote_refs(self, url: str, base: Optional[str] = None):
def iter_remote_refs(self, url: str, base: Optional[str] = None, **kwargs):
raise NotImplementedError

def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
Expand All @@ -465,6 +465,7 @@ def push_refspec(
dest: str,
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
raise NotImplementedError

Expand All @@ -474,6 +475,7 @@ def fetch_refspecs(
refspecs: Iterable[str],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
raise NotImplementedError

Expand Down Expand Up @@ -627,5 +629,5 @@ def merge(
raise SCMError("Merge failed") from exc
return None

def validate_git_remote(self, url: str):
def validate_git_remote(self, url: str, **kwargs):
raise NotImplementedError
6 changes: 4 additions & 2 deletions dvc/scm/git/backend/pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def push_refspec(
dest: str,
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
raise NotImplementedError

Expand All @@ -383,6 +384,7 @@ def fetch_refspecs(
refspecs: Iterable[str],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
**kwargs,
):
raise NotImplementedError

Expand Down Expand Up @@ -537,7 +539,7 @@ def checkout_index(
index.add(entry.path)
index.write()

def iter_remote_refs(self, url: str, base: Optional[str] = None):
def iter_remote_refs(self, url: str, base: Optional[str] = None, **kwargs):
raise NotImplementedError

def status(
Expand Down Expand Up @@ -584,5 +586,5 @@ def merge(
self.repo.index.write()
return None

def validate_git_remote(self, url: str):
def validate_git_remote(self, url: str, **kwargs):
raise NotImplementedError
11 changes: 11 additions & 0 deletions tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,14 @@ services:
- 2222
volumes:
- ./remotes/user.key.pub:/tmp/key

git-server:
image: ghcr.io/linuxserver/openssh-server
environment:
- USER_NAME=user
- PUBLIC_KEY_FILE=/tmp/key
ports:
- 2222
volumes:
- ./remotes/user.key.pub:/tmp/key
- ./remotes/git-init:/config/custom-cont-init.d
Loading

0 comments on commit 73f2fa8

Please sign in to comment.