Skip to content

Commit

Permalink
Merge pull request apache#670 from jey/ec2-ssh-improvements
Browse files Browse the repository at this point in the history
EC2 SSH improvements
  • Loading branch information
rxin committed Sep 26, 2013
2 parents c514cd1 + e86d1d4 commit 76677b8
Showing 1 changed file with 80 additions and 26 deletions.
106 changes: 80 additions & 26 deletions ec2/spark_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import logging
import os
import pipes
import random
import shutil
import subprocess
Expand All @@ -36,6 +37,9 @@
from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType
from boto import ec2

class UsageError(Exception):
    """Error whose message is meant to be shown to the user as-is.

    main() catches this type and prints the message to stderr rather than
    letting a traceback escape.
    """

# A URL prefix from which to fetch AMI information
AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list"

Expand Down Expand Up @@ -103,11 +107,7 @@ def parse_args():
parser.print_help()
sys.exit(1)
(action, cluster_name) = args
if opts.identity_file == None and action in ['launch', 'login', 'start']:
print >> stderr, ("ERROR: The -i or --identity-file argument is " +
"required for " + action)
sys.exit(1)


# Boto config check
# http://boto.cloudhackers.com/en/latest/boto_config_tut.html
home_dir = os.getenv('HOME')
Expand Down Expand Up @@ -390,10 +390,18 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
master = master_nodes[0].public_dns_name
if deploy_ssh_key:
print "Copying SSH key %s to master..." % opts.identity_file
ssh(master, opts, 'mkdir -p ~/.ssh')
scp(master, opts, opts.identity_file, '~/.ssh/id_rsa')
ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa')
print "Generating cluster's SSH key on master..."
key_setup = """
[ -f ~/.ssh/id_rsa ] ||
(ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys)
"""
ssh(master, opts, key_setup)
dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
print "Transferring cluster's SSH key to slaves..."
for slave in slave_nodes:
print slave.public_dns_name
ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)

modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
'mapreduce', 'spark-standalone']
Expand Down Expand Up @@ -535,18 +543,33 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
dest.write(text)
dest.close()
# rsync the whole directory over to the master machine
command = (("rsync -rv -e 'ssh -o StrictHostKeyChecking=no -i %s' " +
"'%s/' '%s@%s:/'") % (opts.identity_file, tmp_dir, opts.user, active_master))
subprocess.check_call(command, shell=True)
command = [
'rsync', '-rv',
'-e', stringify_command(ssh_command(opts)),
"%s/" % tmp_dir,
"%s@%s:/" % (opts.user, active_master)
]
subprocess.check_call(command)
# Remove the temp directory we created above
shutil.rmtree(tmp_dir)


# Copy a file to a given host through scp, throwing an exception if scp fails
def scp(host, opts, local_file, dest_file):
    """Copy local_file to dest_file on host using scp.

    host       -- remote host name; combined with opts.user as user@host
    opts       -- parsed options; opts.user is read here and the object is
                  handed to ssh_args() for the -o / -i flags
    local_file -- path of the file to upload
    dest_file  -- destination path on the remote host

    Raises subprocess.CalledProcessError if scp exits with a non-zero status.
    """
    # An argument list (no shell=True) keeps paths containing spaces or quotes
    # intact, and ssh_args() leaves out -i when no identity file was given
    # instead of passing the literal text "None".
    argv = ['scp', '-q'] + ssh_args(opts)
    argv.append(local_file)
    argv.append('%s@%s:%s' % (opts.user, host, dest_file))
    subprocess.check_call(argv)
def stringify_command(parts):
    """Render a command as one string suitable for passing to a shell.

    parts -- either a ready-made command string, returned unchanged, or an
             iterable of argument strings, each of which is shell-quoted and
             then joined with single spaces.

    Returns the command as a single string.
    """
    # On Python 2 a unicode command must be treated as a string too; checking
    # only `str` would iterate it character by character and quote each one.
    try:
        string_types = basestring
    except NameError:
        string_types = str
    if isinstance(parts, string_types):
        return parts

    # shlex.quote is the current name of the helper; pipes.quote is the
    # spelling available on older interpreters.
    try:
        from shlex import quote
    except ImportError:
        from pipes import quote
    return ' '.join(quote(part) for part in parts)


def ssh_args(opts):
    """Return the option flags shared by every ssh invocation.

    Host key checking is always disabled; "-i <file>" is appended only when
    opts.identity_file is set.
    """
    identity = opts.identity_file
    flags = ['-o', 'StrictHostKeyChecking=no']
    if identity is None:
        return flags
    return flags + ['-i', identity]


def ssh_command(opts):
    """Return the ssh argument list: the program name plus ssh_args(opts)."""
    argv = ['ssh']
    argv.extend(ssh_args(opts))
    return argv


# Run a command on a host through ssh, retrying up to two times
Expand All @@ -556,18 +579,42 @@ def ssh(host, opts, command):
while True:
try:
return subprocess.check_call(
"ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" %
(opts.identity_file, opts.user, host, command), shell=True)
ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)])
except subprocess.CalledProcessError as e:
if (tries > 2):
raise e
print "Couldn't connect to host {0}, waiting 30 seconds".format(e)
# If this was an ssh failure, provide the user with hints.
if e.returncode == 255:
raise UsageError("Failed to SSH to remote host {0}.\nPlease check that you have provided the correct --identity-file and --key-pair parameters and try again.".format(host))
else:
raise e
print >> stderr, "Error executing remote command, retrying after 30 seconds: {0}".format(e)
time.sleep(30)
tries = tries + 1


def ssh_read(host, opts, command):
    """Run a command on a host through ssh and return its standard output.

    command may be a string or an argument list; it is rendered with
    stringify_command(). subprocess.check_output raises CalledProcessError
    when ssh exits with a non-zero status.
    """
    argv = list(ssh_command(opts))
    argv.append('%s@%s' % (opts.user, host))
    argv.append(stringify_command(command))
    return subprocess.check_output(argv)


def ssh_write(host, opts, command, input):
    """Run a command on a host through ssh, feeding `input` to its stdin.

    host    -- remote host name; combined with opts.user as user@host
    opts    -- parsed options; opts.user is read here and the object is
               handed to ssh_command()
    command -- command string or argument list, rendered by
               stringify_command()
    input   -- data written to the remote command's standard input

    A non-zero exit status is retried after a 30 second pause; the fourth
    consecutive failure raises RuntimeError.
    """
    argv = ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]
    tries = 0
    while True:
        proc = subprocess.Popen(argv, stdin=subprocess.PIPE)
        # communicate() sends the data, closes stdin and waits for the child.
        # A bare stdin.write() raises a broken-pipe error when ssh exits
        # before reading everything, which would skip the retry logic below
        # and leave the child unreaped.
        proc.communicate(input)
        status = proc.returncode
        if status == 0:
            break
        elif tries > 2:
            raise RuntimeError("ssh_write failed with error %s" % status)
        else:
            stderr.write(
                "Error {0} while executing remote command, retrying after 30 seconds".format(status)
                + "\n")
            time.sleep(30)
            tries = tries + 1


# Gets a list of zones to launch instances in
def get_zones(conn, opts):
Expand All @@ -586,7 +633,7 @@ def get_partition(total, num_partitions, current_partitions):
return num_slaves_this_zone


def main():
def real_main():
(opts, action, cluster_name) = parse_args()
try:
conn = ec2.connect_to_region(opts.region)
Expand Down Expand Up @@ -669,11 +716,11 @@ def main():
conn, opts, cluster_name)
master = master_nodes[0].public_dns_name
print "Logging into master " + master + "..."
proxy_opt = ""
proxy_opt = []
if opts.proxy_port != None:
proxy_opt = "-D " + opts.proxy_port
subprocess.check_call("ssh -o StrictHostKeyChecking=no -i %s %s %s@%s" %
(opts.identity_file, proxy_opt, opts.user, master), shell=True)
proxy_opt = ['-D', opts.proxy_port]
subprocess.check_call(
ssh_command(opts) + proxy_opt + ['-t', "%s@%s" % (opts.user, master)])

elif action == "get-master":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
Expand Down Expand Up @@ -715,6 +762,13 @@ def main():
sys.exit(1)


def main():
    """Entry point: run real_main() and report UsageError as a plain message.

    A UsageError is printed to stderr and the process exits with status 1 so
    that calling scripts can detect the failure. Any other exception
    propagates unchanged.
    """
    try:
        real_main()
    except UsageError as e:
        stderr.write("\nError:\n%s\n" % (e,))
        # Without an explicit exit the process would report success (status 0)
        # after a failed run.
        sys.exit(1)


if __name__ == "__main__":
    logging.basicConfig()
    main()

0 comments on commit 76677b8

Please sign in to comment.