Skip to content

Commit

Permalink
Switch to typed classes for connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
Fizzadar committed Jan 14, 2024
1 parent 3c38ba8 commit 331865a
Show file tree
Hide file tree
Showing 26 changed files with 2,093 additions and 2,026 deletions.
102 changes: 63 additions & 39 deletions pyinfra/api/command.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import shlex
from inspect import getfullargspec
from string import Formatter
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Union, Unpack

import gevent

from pyinfra.context import ctx_config, ctx_host

from .arguments import get_executor_kwarg_keys
from .arguments import ConnectorArguments

if TYPE_CHECKING:
from pyinfra.api.host import Host
from pyinfra.api.state import State


def make_formatted_string_command(string: str, *args, **kwargs):
def make_formatted_string_command(string: str, *args, **kwargs) -> "StringCommand":
"""
Helper function that takes a shell command or script as a string, splits it
using ``shlex.split`` and then formats each bit, returning a ``StringCommand``
Expand Down Expand Up @@ -50,39 +50,46 @@ class MaskString(str):


class QuoteString:
def __init__(self, obj):
self.object = obj
obj: Union[str, "StringCommand"]

def __repr__(self):
return "QuoteString({0})".format(self.object)
def __init__(self, obj: Union[str, "StringCommand"]):
self.obj = obj

def __repr__(self) -> str:
return f"QuoteString({self.obj})"


class PyinfraCommand:
def __init__(self, *args, **kwargs):
self.executor_kwargs = {
key: kwargs[key] for key in get_executor_kwarg_keys() if key in kwargs
}
connector_arguments: ConnectorArguments

def __init__(self, **arguments: Unpack[ConnectorArguments]):
self.connector_arguments = arguments

def __eq__(self, other):
def __eq__(self, other) -> bool:
if isinstance(other, self.__class__) and repr(self) == repr(other):
return True
return False

def execute(self, state: "State", host: "Host", executor_kwargs):
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
raise NotImplementedError


class StringCommand(PyinfraCommand):
def __init__(self, *bits, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
*bits,
_separator=" ",
**arguments: Unpack[ConnectorArguments],
):
super().__init__(**arguments)
self.bits = bits
self.separator = kwargs.pop("_separator", " ")
self.separator = _separator

def __str__(self):
def __str__(self) -> str:
return self.get_masked_value()

def __repr__(self):
return "StringCommand({0})".format(self.get_masked_value())
def __repr__(self) -> str:
return f"StringCommand({self.get_masked_value()})"

def _get_all_bits(self, bit_accessor):
all_bits = []
Expand All @@ -91,7 +98,7 @@ def _get_all_bits(self, bit_accessor):
quote = False
if isinstance(bit, QuoteString):
quote = True
bit = bit.object
bit = bit.obj

if isinstance(bit, StringCommand):
bit = bit_accessor(bit)
Expand All @@ -106,35 +113,40 @@ def _get_all_bits(self, bit_accessor):

return all_bits

def get_raw_value(self):
def get_raw_value(self) -> str:
return self.separator.join(
self._get_all_bits(
lambda bit: bit.get_raw_value(),
),
)

def get_masked_value(self):
def get_masked_value(self) -> str:
return self.separator.join(
[
"***" if isinstance(bit, MaskString) else bit
for bit in self._get_all_bits(lambda bit: bit.get_masked_value())
],
)

def execute(self, state: "State", host: "Host", executor_kwargs):
executor_kwargs.update(self.executor_kwargs)
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
connector_arguments.update(self.connector_arguments)

return host.run_shell_command(
self,
print_output=state.print_output,
print_input=state.print_input,
return_combined_output=True,
**executor_kwargs,
**connector_arguments,
)


class FileUploadCommand(PyinfraCommand):
def __init__(self, src: str, dest: str, remote_temp_filename=None, **kwargs):
def __init__(
self,
src: str,
dest: str,
remote_temp_filename=None,
**kwargs: Unpack[ConnectorArguments],
):
super().__init__(**kwargs)
self.src = src
self.dest = dest
Expand All @@ -143,21 +155,27 @@ def __init__(self, src: str, dest: str, remote_temp_filename=None, **kwargs):
def __repr__(self):
return "FileUploadCommand({0}, {1})".format(self.src, self.dest)

def execute(self, state: "State", host: "Host", executor_kwargs):
executor_kwargs.update(self.executor_kwargs)
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
connector_arguments.update(self.connector_arguments)

return host.put_file(
self.src,
self.dest,
remote_temp_filename=self.remote_temp_filename,
print_output=state.print_output,
print_input=state.print_input,
**executor_kwargs,
**connector_arguments,
)


class FileDownloadCommand(PyinfraCommand):
def __init__(self, src: str, dest: str, remote_temp_filename=None, **kwargs):
def __init__(
self,
src: str,
dest: str,
remote_temp_filename=None,
**kwargs: Unpack[ConnectorArguments],
):
super().__init__(**kwargs)
self.src = src
self.dest = dest
Expand All @@ -166,21 +184,27 @@ def __init__(self, src: str, dest: str, remote_temp_filename=None, **kwargs):
def __repr__(self):
return "FileDownloadCommand({0}, {1})".format(self.src, self.dest)

def execute(self, state: "State", host: "Host", executor_kwargs):
executor_kwargs.update(self.executor_kwargs)
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
connector_arguments.update(self.connector_arguments)

return host.get_file(
self.src,
self.dest,
remote_temp_filename=self.remote_temp_filename,
print_output=state.print_output,
print_input=state.print_input,
**executor_kwargs,
**connector_arguments,
)


class FunctionCommand(PyinfraCommand):
def __init__(self, function, args, func_kwargs, **kwargs):
def __init__(
self,
function: Callable,
args,
func_kwargs,
**kwargs: Unpack[ConnectorArguments],
):
super().__init__(**kwargs)
self.function = function
self.args = args
Expand All @@ -193,7 +217,7 @@ def __repr__(self):
self.kwargs,
)

def execute(self, state: "State", host: "Host", executor_kwargs):
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
argspec = getfullargspec(self.function)
if "state" in argspec.args and "host" in argspec.args:
return self.function(state, host, *self.args, **self.kwargs)
Expand All @@ -208,7 +232,7 @@ def execute_function():


class RsyncCommand(PyinfraCommand):
def __init__(self, src: str, dest: str, flags, **kwargs):
def __init__(self, src: str, dest: str, flags, **kwargs: Unpack[ConnectorArguments]):
super().__init__(**kwargs)
self.src = src
self.dest = dest
Expand All @@ -217,12 +241,12 @@ def __init__(self, src: str, dest: str, flags, **kwargs):
def __repr__(self):
return "RsyncCommand({0}, {1}, {2})".format(self.src, self.dest, self.flags)

def execute(self, state: "State", host: "Host", executor_kwargs):
def execute(self, state: "State", host: "Host", connector_arguments: ConnectorArguments):
return host.rsync(
self.src,
self.dest,
self.flags,
print_output=state.print_output,
print_input=state.print_input,
**executor_kwargs,
**connector_arguments,
)
2 changes: 1 addition & 1 deletion pyinfra/api/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def connect_all(state: "State"):
# Raise any unexpected exception
greenlet.get()

if host.connection:
if host.connected:
state.activate_host(host)
else:
failed_hosts.add(host)
Expand Down
26 changes: 2 additions & 24 deletions pyinfra/api/connectors.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,8 @@
import pkg_resources


class BaseConnectorMeta:
handles_execution = False
keys_prefix = ""

class DataKeys:
pass

@classmethod
def keys(cls):
class Keys:
pass

for key in cls.DataKeys.__dict__:
if not key.startswith("_"):
setattr(Keys, key, f"{cls.keys_prefix}_{key}")

return Keys


def _load_connector(entrypoint):
connector = entrypoint.load()
if not getattr(connector, "Meta", None):
connector.Meta = BaseConnectorMeta
return connector
return entrypoint.load()


def get_all_connectors():
Expand All @@ -38,7 +16,7 @@ def get_execution_connectors():
return {
connector: connector_mod
for connector, connector_mod in get_all_connectors().items()
if connector_mod.Meta.handles_execution
if connector_mod.handles_execution
}


Expand Down
Loading

0 comments on commit 331865a

Please sign in to comment.