diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 6b7d202a88174..1190ed47f6bcc 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -23,6 +23,7 @@ import logging import os +import pipes import random import shutil import subprocess @@ -36,6 +37,9 @@ from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType from boto import ec2 +class UsageError(Exception): + pass + # A URL prefix from which to fetch AMI information AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" @@ -103,11 +107,7 @@ def parse_args(): parser.print_help() sys.exit(1) (action, cluster_name) = args - if opts.identity_file == None and action in ['launch', 'login', 'start']: - print >> stderr, ("ERROR: The -i or --identity-file argument is " + - "required for " + action) - sys.exit(1) - + # Boto config check # http://boto.cloudhackers.com/en/latest/boto_config_tut.html home_dir = os.getenv('HOME') @@ -390,10 +390,18 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): master = master_nodes[0].public_dns_name if deploy_ssh_key: - print "Copying SSH key %s to master..." % opts.identity_file - ssh(master, opts, 'mkdir -p ~/.ssh') - scp(master, opts, opts.identity_file, '~/.ssh/id_rsa') - ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa') + print "Generating cluster's SSH key on master..." + key_setup = """ + [ -f ~/.ssh/id_rsa ] || + (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && + cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys) + """ + ssh(master, opts, key_setup) + dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) + print "Transferring cluster's SSH key to slaves..." + for slave in slave_nodes: + print slave.public_dns_name + ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs', 'mapreduce', 'spark-standalone'] @@ -535,18 +543,33 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): dest.write(text) dest.close() # rsync the whole directory over to the master machine - command = (("rsync -rv -e 'ssh -o StrictHostKeyChecking=no -i %s' " + - "'%s/' '%s@%s:/'") % (opts.identity_file, tmp_dir, opts.user, active_master)) - subprocess.check_call(command, shell=True) + command = [ + 'rsync', '-rv', + '-e', stringify_command(ssh_command(opts)), + "%s/" % tmp_dir, + "%s@%s:/" % (opts.user, active_master) + ] + subprocess.check_call(command) # Remove the temp directory we created above shutil.rmtree(tmp_dir) -# Copy a file to a given host through scp, throwing an exception if scp fails -def scp(host, opts, local_file, dest_file): - subprocess.check_call( - "scp -q -o StrictHostKeyChecking=no -i %s '%s' '%s@%s:%s'" % - (opts.identity_file, local_file, opts.user, host, dest_file), shell=True) +def stringify_command(parts): + if isinstance(parts, str): + return parts + else: + return ' '.join(map(pipes.quote, parts)) + + +def ssh_args(opts): + parts = ['-o', 'StrictHostKeyChecking=no'] + if opts.identity_file is not None: + parts += ['-i', opts.identity_file] + return parts + + +def ssh_command(opts): + return ['ssh'] + ssh_args(opts) # Run a command on a host through ssh, retrying up to two times @@ -556,18 +579,42 @@ def ssh(host, opts, command): while True: try: return subprocess.check_call( - "ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" % - (opts.identity_file, opts.user, host, command), shell=True) + ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)]) except subprocess.CalledProcessError as e: if (tries > 2): - raise e - print "Couldn't connect to host {0}, waiting 30 seconds".format(e) + # If this was an ssh failure, provide the user with hints. + if e.returncode == 255: + raise UsageError("Failed to SSH to remote host {0}.\nPlease check that you have provided the correct --identity-file and --key-pair parameters and try again.".format(host)) + else: + raise e + print >> stderr, "Error executing remote command, retrying after 30 seconds: {0}".format(e) time.sleep(30) tries = tries + 1 +def ssh_read(host, opts, command): + return subprocess.check_output( + ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) +def ssh_write(host, opts, command, input): + tries = 0 + while True: + proc = subprocess.Popen( + ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], + stdin=subprocess.PIPE) + proc.stdin.write(input) + proc.stdin.close() + status = proc.wait() + if status == 0: + break + elif (tries > 2): + raise RuntimeError("ssh_write failed with error %s" % proc.returncode) + else: + print >> stderr, "Error {0} while executing remote command, retrying after 30 seconds".format(status) + time.sleep(30) + tries = tries + 1 + # Gets a list of zones to launch instances in def get_zones(conn, opts): @@ -586,7 +633,7 @@ def get_partition(total, num_partitions, current_partitions): return num_slaves_this_zone -def main(): +def real_main(): (opts, action, cluster_name) = parse_args() try: conn = ec2.connect_to_region(opts.region) @@ -669,11 +716,11 @@ def main(): conn, opts, cluster_name) master = master_nodes[0].public_dns_name print "Logging into master " + master + "..." - proxy_opt = "" + proxy_opt = [] if opts.proxy_port != None: - proxy_opt = "-D " + opts.proxy_port - subprocess.check_call("ssh -o StrictHostKeyChecking=no -i %s %s %s@%s" % - (opts.identity_file, proxy_opt, opts.user, master), shell=True) + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', "%s@%s" % (opts.user, master)]) elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) @@ -715,6 +762,13 @@ def main(): sys.exit(1) +def main(): + try: + real_main() + except UsageError, e: + print >> stderr, "\nError:\n", e + + if __name__ == "__main__": logging.basicConfig() main()