Skip to content

Commit

Permalink
Construct shell commands as sequences for safety and composability
Browse files Browse the repository at this point in the history
  • Loading branch information
jey committed Sep 6, 2013
1 parent 1e15feb commit 6919a28
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions ec2/spark_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import logging
import os
import pipes
import random
import shutil
import subprocess
Expand Down Expand Up @@ -536,18 +537,41 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
dest.write(text)
dest.close()
# rsync the whole directory over to the master machine
command = (("rsync -rv -e 'ssh -o StrictHostKeyChecking=no -i %s' " +
"'%s/' '%s@%s:/'") % (opts.identity_file, tmp_dir, opts.user, active_master))
subprocess.check_call(command, shell=True)
command = [
'rsync', '-rv',
'-e', stringify_command(ssh_command(opts)),
"%s/" % tmp_dir,
"%s@%s:/" % (opts.user, active_master)
]
subprocess.check_call(command)
# Remove the temp directory we created above
shutil.rmtree(tmp_dir)


def stringify_command(parts):
if isinstance(parts, str):
return parts
else:
return ' '.join(map(pipes.quote, parts))


def ssh_args(opts):
parts = ['-o', 'StrictHostKeyChecking=no', '-i', opts.identity_file]
return parts


def ssh_command(opts):
return ['ssh'] + ssh_args(opts)


def scp_command(opts):
return ['scp', '-q'] + ssh_args(opts)


# Copy a file to a given host through scp, throwing an exception if scp fails
def scp(host, opts, local_file, dest_file):
subprocess.check_call(
"scp -q -o StrictHostKeyChecking=no -i %s '%s' '%s@%s:%s'" %
(opts.identity_file, local_file, opts.user, host, dest_file), shell=True)
scp_command(opts) + [local_file, "%s@%s:%s" % (opts.user, host, dest_file)])


# Run a command on a host through ssh, retrying up to two times
Expand All @@ -557,8 +581,7 @@ def ssh(host, opts, command):
while True:
try:
return subprocess.check_call(
"ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" %
(opts.identity_file, opts.user, host, command), shell=True)
ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)])
except subprocess.CalledProcessError as e:
if (tries > 2):
raise e
Expand Down Expand Up @@ -670,11 +693,11 @@ def main():
conn, opts, cluster_name)
master = master_nodes[0].public_dns_name
print "Logging into master " + master + "..."
proxy_opt = ""
proxy_opt = []
if opts.proxy_port != None:
proxy_opt = "-D " + opts.proxy_port
subprocess.check_call("ssh -o StrictHostKeyChecking=no -i %s %s %s@%s" %
(opts.identity_file, proxy_opt, opts.user, master), shell=True)
proxy_opt = ['-D', opts.proxy_port]
subprocess.check_call(
ssh_command(opts) + proxy_opt + ['-t', "%s@%s" % (opts.user, master)])

elif action == "get-master":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
Expand Down

0 comments on commit 6919a28

Please sign in to comment.