From 8b41767380a133317bb98a9df4a2afb383bac7cc Mon Sep 17 00:00:00 2001 From: brettlangdon Date: Fri, 27 Sep 2013 20:39:39 -0400 Subject: [PATCH] make ec2 manager work --- employ/managers/__init__.py | 17 +++++ employ/managers/ec2.py | 129 +++++++++++++++++++++++++++++++----- requirements.txt | 1 + 3 files changed, 131 insertions(+), 16 deletions(-) diff --git a/employ/managers/__init__.py b/employ/managers/__init__.py index 44bb50a..5cb8a5a 100644 --- a/employ/managers/__init__.py +++ b/employ/managers/__init__.py @@ -30,3 +30,20 @@ class Manager(object): def run(self, command): raise NotImplementedError() + + def validate_results(self, results, command): + """ + Helper method to validate the results of running commands. + + :param results: the (status, stdout, stderr) results from running `command` + :type results: list + :param command: the raw str command that was run + :type command: str + :raises: :class:``employ.exceptions.ExecutionError`` + """ + for status, stdout, stderr in results: + if status != 0: + raise ExecutionError( + "Non-Zero status code from executing command: %s" % command, + command, status, stdout, stderr, + ) diff --git a/employ/managers/ec2.py b/employ/managers/ec2.py index 68b1b1c..f1fc7c8 100644 --- a/employ/managers/ec2.py +++ b/employ/managers/ec2.py @@ -1,8 +1,12 @@ import time +from os.path import basename +from threading import Thread +from Queue import Queue import boto.ec2 -from boto.manage.cmdshell import sshclient_from_instance +import paramiko +from employ.logger import logger from employ.exceptions import SSHConnectionError from employ.managers import Manager @@ -28,15 +32,32 @@ class EC2Manager(Manager): ; all instances have the state "running", this interval ; is how long the manager will wait between checking states wait_interval = 5 + + ; when attempting to gain an ssh connection, fail after + ; connection_attempts attempts + connection_attempts = 10 """ name = "ec2" def __init__( - self, ami_image_id="ami-da0cf8b3", 
num_instances=1, instance_name="employed", - region="us-east-1", instance_type="t1.micro", key_name=None, security_group="default", - user_name="root", host_key="~/.ssh/known_hosts", ssh_pwd=None, wait_interval=5 + self, ami_image_id="ami-da0cf8b3", num_instances=1, + instance_name="employed", region="us-east-1", + instance_type="t1.micro", key_name=None, + security_group="default", user_name="root", + host_key="~/.ssh/known_hosts", ssh_pwd=None, + wait_interval=5, connection_attempts=10 ): + """ + Constructor for :class:``employ.managers.EC2Manager`` + + :param ami_image_id: the ec2 ami image to use + :type ami_image_id: str + :param num_instances: the number of ec2 instances to start + :type num_instances: int + :param instance_name: the name to assign to each instance + """ self.instances = [] + self.client_connections = [] self.ami_image_id = ami_image_id self.num_instances = num_instances self.instance_name = instance_name @@ -48,6 +69,7 @@ class EC2Manager(Manager): self.host_key = host_key self.ssh_pwd = ssh_pwd self.wait_interval = wait_interval + self.connection_attempts = connection_attempts self._connection = None @classmethod @@ -64,7 +86,7 @@ class EC2Manager(Manager): return [instance.id for instance in self.instances] def setup_instances(self): - print "starting instances" + logger.info("starting %s instances", self.num_instances) connection = self.connection() reservation = connection.run_instances( image_id=self.ami_image_id, @@ -77,25 +99,100 @@ class EC2Manager(Manager): self.instances = reservation.instances connection.create_tags(self.instance_ids(), {"Name": self.instance_name}) - print "waiting until they are all running" + logger.info("waiting until all instances are all 'running'") while not all(instance.update() == "running" for instance in self.instances): time.sleep(self.wait_interval) - __enter__ = setup_instances + # connections usually take a bit, might as well wait + # a little bit before making the first attempt + 
time.sleep(self.wait_interval) + logger.info("establishing ssh connections") + for instance in self.instances: + for _ in xrange(self.connection_attempts): + logger.info( + "Attempting connection to %s@%s", + self.user_name, instance.ip_address + ) + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + client.connect( + instance.ip_address, username=self.user_name, + key_filename=self.host_key + ) + self.client_connections.append(client) + break + except Exception: + pass + time.sleep(self.wait_interval) + else: + raise SSHConnectionError( + "Could not establish ssh connection to %s@%s after %s attempts", + self.user_name, instance.ip_address, self.connection_attempts + ) def cleanup_instances(self): + for client in self.client_connections: + client.close() + connection = self.connection() connection.terminate_instances(instance_ids=self.instance_ids()) - def __exit__(self, type, value, traceback): - self.cleanup_instances() - def setup(self, script): - print "setup script: %s" % script + remote_file = "/tmp/%s" % basename(script) + workers = [] + for client in self.client_connections: + worker = Thread(target=self._put_file, args=(client, script, remote_file)) + worker.daemon = True + worker.start() + workers.append(worker) + for worker in workers: + worker.join() + + command = "/bin/sh %s" % remote_file + results = self._run_multi(command) + self.validate_results(results, command) def run(self, command): - # shell = sshclient_from_instance( - # self.instances[0], self.host_key, user_name=self.user_name - # ) - # command.aggregate(shell.run(command.command())) - print "running command: %s" % command.command() + execute = command.command() + results = self._run_multi(execute) + self.validate_results(results, execute) + command.aggregate(results) + + def _run_command(self, client, command, results): + transport = client.get_transport() + channel = transport.open_session() + logger.info("executing command %s", command) 
+ channel.get_pty() + channel.exec_command(command) + status = int(channel.recv_exit_status()) + stdout = "" + while channel.recv_ready(): + stdout += channel.recv(1024) + stderr = "" + while channel.recv_stderr_ready(): + stderr += channel.recv_stderr(1024) + results.put((status, stdout, stderr)) + + def _run_multi(self, command): + results = Queue() + workers = [] + for client in self.client_connections: + worker = Thread(target=self._run_command, args=(client, command, results)) + worker.daemon = True + worker.start() + workers.append(worker) + + for worker in workers: + worker.join() + + all_results = [] + while not results.empty(): + all_results.append(results.get()) + return all_results + + def _put_file(self, client, script, remote_file): + fp = open(script, "r") + transport = client.get_transport() + sftp_client = paramiko.SFTPClient.from_transport(transport) + sftp_client.putfo(fp, remote_file) diff --git a/requirements.txt b/requirements.txt index 76862c6..85be3e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ boto>=2.13.0 docopt>=0.6.0 +paramiko>=1.11.0 straight.plugin>=1.4.0