Skip to content

Commit

Permalink
machine: move to tpi[ssh] (iterative#6528)
Browse files (browse the repository at this point in the history)
* machine: move to tpi[ssh]

- depends on https://github.com/iterative/tpi

* lazy import tpi

* fix lint warning

Co-authored-by: Peter Rowlands <[email protected]>
  • Loading branch information
casperdcl and pmrowla authored Sep 3, 2021
1 parent 2778c05 commit 0578643
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 417 deletions.
36 changes: 25 additions & 11 deletions dvc/machine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
Type,
)

from dvc.exceptions import DvcException
from dvc.types import StrPath

from .backend.base import BaseMachineBackend
from .backend.terraform import TerraformBackend

if TYPE_CHECKING:
from dvc.repo import Repo

logger = logging.getLogger(__name__)
from .backend.base import BaseMachineBackend

BackendCls = Type[BaseMachineBackend]

BackendCls = Type[BaseMachineBackend]
logger = logging.getLogger(__name__)


RESERVED_NAMES = {"local", "localhost"}
Expand All @@ -37,11 +37,16 @@ def validate_name(name: str):


class MachineBackends(Mapping):
DEFAULT: Dict[str, BackendCls] = {
try:
from .backend.terraform import TerraformBackend
except ImportError:
TerraformBackend = None # type: ignore[assignment, misc]

DEFAULT: Dict[str, Optional["BackendCls"]] = {
"terraform": TerraformBackend,
}

def __getitem__(self, key: str) -> BaseMachineBackend:
def __getitem__(self, key: str) -> "BaseMachineBackend":
"""Lazily initialize backends and cache it afterwards"""
initialized = self.initialized.get(key)
if not initialized:
Expand All @@ -59,9 +64,18 @@ def __init__(
**kwargs,
) -> None:
selected = selected or list(self.DEFAULT)
self.backends = {key: self.DEFAULT[key] for key in selected}

self.initialized: Dict[str, BaseMachineBackend] = {}
self.backends: Dict[str, "BackendCls"] = {}
for key in selected:
cls = self.DEFAULT.get(key)
if cls is None:
raise DvcException(
f"'dvc machine' backend '{key}' is missing required "
"dependencies. Install them with:\n"
f"\tpip install dvc[{key}]"
)
self.backends[key] = cls

self.initialized: Dict[str, "BaseMachineBackend"] = {}

self.tmp_dir = tmp_dir
self.kwargs = kwargs
Expand Down Expand Up @@ -142,7 +156,7 @@ def _get_config(self, **kwargs):
conf = kwargs
return conf

def _get_backend(self, cloud: str) -> BaseMachineBackend:
def _get_backend(self, cloud: str) -> "BaseMachineBackend":
from dvc.config import NoMachineError

try:
Expand Down
67 changes: 3 additions & 64 deletions dvc/machine/backend/base.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,15 @@
import asyncio
import logging
import os
import sys
from abc import ABC, abstractmethod
from abc import abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional

from dvc.exceptions import DvcException
from dvc.types import StrPath
from dvc.utils.fs import makedirs
from tpi import base

if TYPE_CHECKING:
from dvc.fs.ssh import SSHFileSystem
from dvc.repo.experiments.executor.base import BaseExecutor

logger = logging.getLogger(__name__)


class BaseMachineBackend(ABC):
def __init__(self, tmp_dir: StrPath, **kwargs):
self.tmp_dir = tmp_dir
makedirs(self.tmp_dir, exist_ok=True)

@abstractmethod
def create(self, name: Optional[str] = None, **config):
"""Create and start an instance of the specified machine."""

@abstractmethod
def destroy(self, name: Optional[str] = None, **config):
"""Stop and destroy all instances of the specified machine."""

@abstractmethod
def instances(
self, name: Optional[str] = None, **config
) -> Iterator[dict]:
"""Iterate over status of all instances of the specified machine."""

def close(self):
pass

class BaseMachineBackend(base.BaseMachineBackend):
@abstractmethod
def get_executor(
self, name: Optional[str] = None, **config
Expand All @@ -54,35 +25,3 @@ def get_sshfs(
) -> Iterator["SSHFileSystem"]:
"""Return an sshfs instance for the default directory on the
specified machine."""

@abstractmethod
def run_shell(self, name: Optional[str] = None, **config):
"""Spawn an interactive SSH shell for the specified machine."""

def _shell(self, *args, **kwargs):
"""Sync wrapper for an asyncssh shell session.
Args will be passed into asyncssh.connect().
"""
import asyncssh

loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.run_until_complete(self._shell_async(*args, **kwargs))
except (OSError, asyncssh.Error) as exc:
raise DvcException("SSH connection failed") from exc
finally:
asyncio.set_event_loop(None)
loop.close()

async def _shell_async(self, *args, **kwargs):
import asyncssh

async with asyncssh.connect(*args, **kwargs) as conn:
await conn.run(
term_type=os.environ.get("TERM", "xterm"),
stdin=sys.stdin,
stdout=sys.stdout,
stderr=sys.stderr,
)
80 changes: 4 additions & 76 deletions dvc/machine/backend/terraform.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,29 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional

from funcy import first
from tpi import TerraformProviderIterative, terraform

from dvc.fs.ssh import SSHFileSystem

from .base import BaseMachineBackend

if TYPE_CHECKING:
from dvc.repo.experiments.executor.base import BaseExecutor

logger = logging.getLogger(__name__)


class TerraformBackend(BaseMachineBackend):
@contextmanager
def make_tf(self, name: str):
from dvc.tpi import DvcTerraform, TerraformError
from dvc.utils.fs import makedirs

try:
working_dir = os.path.join(self.tmp_dir, name)
makedirs(working_dir, exist_ok=True)
yield DvcTerraform(working_dir=working_dir)
except TerraformError:
raise
except Exception as exc:
raise TerraformError("terraform failed") from exc

def create(self, name: Optional[str] = None, **config):
from python_terraform import IsFlagged

from dvc.tpi import render_json

assert name and "cloud" in config
with self.make_tf(name) as tf:
tf_file = os.path.join(tf.working_dir, "main.tf.json")
with open(tf_file, "w", encoding="utf-8") as fobj:
fobj.write(render_json(name=name, **config, indent=2))
tf.cmd("init")
tf.cmd("apply", auto_approve=IsFlagged)

def destroy(self, name: Optional[str] = None, **config):
from python_terraform import IsFlagged

assert name

with self.make_tf(name) as tf:
if first(tf.iter_instances(name)):
tf.cmd("destroy", auto_approve=IsFlagged)

def instances(
self, name: Optional[str] = None, **config
) -> Iterator[dict]:
assert name

with self.make_tf(name) as tf:
yield from tf.iter_instances(name)

class TerraformBackend(terraform.TerraformBackend):
def get_executor(
self, name: Optional[str] = None, **config
) -> "BaseExecutor":
raise NotImplementedError

def _default_resource(self, name):
from dvc.tpi import TerraformError

resource = first(self.instances(name))
if not resource:
raise TerraformError(f"No active '{name}' instances")
return resource

@contextmanager
def get_sshfs(
def get_sshfs( # pylint: disable=unused-argument
self, name: Optional[str] = None, **config
) -> Iterator["SSHFileSystem"]:
from dvc.tpi import DvcTerraform

resource = self._default_resource(name)
with DvcTerraform.pemfile(resource) as pem:
with TerraformProviderIterative.pemfile(resource) as pem:
fs = SSHFileSystem(
host=resource["instance_ip"],
user="ubuntu",
keyfile=pem,
)
yield fs

def run_shell(self, name: Optional[str] = None, **config):
from dvc.tpi import DvcTerraform

resource = self._default_resource(name)
with DvcTerraform.pemfile(resource) as pem:
self._shell(
host=resource["instance_ip"],
username="ubuntu",
client_keys=pem,
known_hosts=None,
)
96 changes: 0 additions & 96 deletions dvc/tpi/__init__.py

This file was deleted.

26 changes: 0 additions & 26 deletions dvc/tpi/templates/main.tf

This file was deleted.

Loading

0 comments on commit 0578643

Please sign in to comment.