Skip to content

Commit

Permalink
Move port and username into sftp connection string
Browse files Browse the repository at this point in the history
  • Loading branch information
soerface committed May 15, 2024
1 parent eb6bcdc commit 0d04311
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@ def write_to_fs(csv_path: Path, df: DataFrame):
])


def write_to_sftp(csv_path: str, df: DataFrame, ssh_key_path: str, ssh_username: str, ssh_port: int = 22):
def write_to_sftp(csv_path: str, df: DataFrame, ssh_key_path: str):
assert csv_path.startswith("sftp://")
sftp_url = csv_path[len("sftp://"):]
hostname, _, path = sftp_url.partition("/")
transport = paramiko.Transport((hostname, ssh_port))
hostname, _, ssh_port = hostname.partition(":")
username, _, hostname = hostname.partition("@")

transport = paramiko.Transport((hostname, int(ssh_port) or 22))
try:
transport.connect(username=ssh_username, pkey=paramiko.RSAKey.from_private_key_file(ssh_key_path))
transport.connect(username=username, pkey=paramiko.RSAKey.from_private_key_file(ssh_key_path))
except paramiko.ssh_exception.AuthenticationException:
logger.error("Authentication failed. Check your credentials")
sys.exit(1)
Expand Down Expand Up @@ -86,9 +89,9 @@ def write_to_sftp(csv_path: str, df: DataFrame, ssh_key_path: str, ssh_username:
transport.close()


def download_data(hostname: str, csv_path: str, ssh_key_path: str | None = None, ssh_username: str | None = None, ssh_port: int = 22):
if csv_path.startswith("sftp://") and not (ssh_key_path and ssh_username):
raise ValueError("ssh_key_path and ssh_username must be provided when using SFTP")
def download_data(hostname: str, csv_path: str, ssh_key_path: str | None = None):
if csv_path.startswith("sftp://") and not ssh_key_path:
raise ValueError("ssh_key_path must be provided when using SFTP")

# phase_url = f"http://{hostname}/emeter/%d/em_data.csv"
phase_url = f"/tmp/tmp.zNG32agH1t/em_data.%d.csv"
Expand All @@ -106,7 +109,7 @@ def download_data(hostname: str, csv_path: str, ssh_key_path: str | None = None,

df["Date"] = df["Date/time UTC"].dt.date
if csv_path.startswith("sftp://"):
write_to_sftp(csv_path, df, ssh_key_path, ssh_username, ssh_port)
write_to_sftp(csv_path, df, ssh_key_path)
else:
write_to_fs(Path(csv_path), df)

Expand Down Expand Up @@ -257,7 +260,6 @@ def main(
help="Do not plot, only download"
)
args.add_argument("--ssh-key-path", type=str, help="Path to the SSH key for SFTP")
args.add_argument("--ssh-username", type=str, help="Username for SFTP")
args.add_argument("--ssh-port", type=int, help="Port for SFTP", default=22)
args.add_argument("--host", type=str, help="Hostname or IP address of the Shelly device", default="192.168.178.99")
args.add_argument("--sample-rate", type=str, help="Sample rate for the data", default="1min")
Expand All @@ -278,7 +280,7 @@ def main(
logger.info("Loglevel set to %s", logging.getLevelName(logger.getEffectiveLevel()))

if args.download or args.download_only:
download_data(args.host, args.csv_path, args.ssh_key_path, args.ssh_username, args.ssh_port)
download_data(args.host, args.csv_path, args.ssh_key_path)
if not args.download_only:
if not args.output_path:
logger.error("--output_path is required when plotting")
Expand Down

0 comments on commit 0d04311

Please sign in to comment.