Skip to content

Commit

Permalink
Generate new SSH key for the cluster, make "--identity-file" optional
Browse files Browse the repository at this point in the history
  • Loading branch information
jey committed Sep 6, 2013
1 parent 6919a28 commit b98572c
Showing 1 changed file with 37 additions and 21 deletions.
58 changes: 37 additions & 21 deletions ec2/spark_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def parse_args():
parser.print_help()
sys.exit(1)
(action, cluster_name) = args
if opts.identity_file == None and action in ['launch', 'login', 'start']:
print >> stderr, ("ERROR: The -i or --identity-file argument is " +
"required for " + action)
sys.exit(1)


# Boto config check
# http://boto.cloudhackers.com/en/latest/boto_config_tut.html
home_dir = os.getenv('HOME')
Expand Down Expand Up @@ -392,10 +388,18 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
master = master_nodes[0].public_dns_name
if deploy_ssh_key:
print "Copying SSH key %s to master..." % opts.identity_file
ssh(master, opts, 'mkdir -p ~/.ssh')
scp(master, opts, opts.identity_file, '~/.ssh/id_rsa')
ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa')
print "Generating cluster's SSH key on master..."
key_setup = """
[ -f ~/.ssh/id_rsa ] ||
(ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys)
"""
ssh(master, opts, key_setup)
dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
print "Transferring cluster's SSH key to slaves..."
for slave in slave_nodes:
print slave.public_dns_name
ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)

modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
'mapreduce', 'spark-standalone']
Expand Down Expand Up @@ -556,24 +560,16 @@ def stringify_command(parts):


def ssh_args(opts):
parts = ['-o', 'StrictHostKeyChecking=no', '-i', opts.identity_file]
parts = ['-o', 'StrictHostKeyChecking=no']
if opts.identity_file is not None:
parts += ['-i', opts.identity_file]
return parts


def ssh_command(opts):
return ['ssh'] + ssh_args(opts)


def scp_command(opts):
return ['scp', '-q'] + ssh_args(opts)


# Copy a file to a given host through scp, throwing an exception if scp fails
def scp(host, opts, local_file, dest_file):
subprocess.check_call(
scp_command(opts) + [local_file, "%s@%s:%s" % (opts.user, host, dest_file)])


# Run a command on a host through ssh, retrying up to two times
# and then throwing an exception if ssh continues to fail.
def ssh(host, opts, command):
Expand All @@ -585,13 +581,33 @@ def ssh(host, opts, command):
except subprocess.CalledProcessError as e:
if (tries > 2):
raise e
print "Couldn't connect to host {0}, waiting 30 seconds".format(e)
print "Error connecting to host, sleeping 30: {0}".format(e)
time.sleep(30)
tries = tries + 1


def ssh_read(host, opts, command):
return subprocess.check_output(
ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)])


def ssh_write(host, opts, command, input):
tries = 0
while True:
proc = subprocess.Popen(
ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)],
stdin=subprocess.PIPE)
proc.stdin.write(input)
proc.stdin.close()
if proc.wait() == 0:
break
elif (tries > 2):
raise RuntimeError("ssh_write error %s" % proc.returncode)
else:
print "Error connecting to host, sleeping 30"
time.sleep(30)
tries = tries + 1


# Gets a list of zones to launch instances in
def get_zones(conn, opts):
Expand Down

0 comments on commit b98572c

Please sign in to comment.