|
|
|
@ -1,8 +1,12 @@ |
|
|
|
import time |
|
|
|
from os.path import basename |
|
|
|
from threading import Thread |
|
|
|
from Queue import Queue |
|
|
|
|
|
|
|
import boto.ec2 |
|
|
|
from boto.manage.cmdshell import sshclient_from_instance |
|
|
|
import paramiko |
|
|
|
|
|
|
|
from employ.logger import logger |
|
|
|
from employ.exceptions import SSHConnectionError |
|
|
|
from employ.managers import Manager |
|
|
|
|
|
|
|
@ -28,15 +32,32 @@ class EC2Manager(Manager): |
|
|
|
; all instances have the state "running", this interval |
|
|
|
; is how long the manager will wait between checking states |
|
|
|
wait_interval = 5 |
|
|
|
|
|
|
|
; when attempting to gain an ssh connection, fail after |
|
|
|
; connection_attempts attempts |
|
|
|
connection_attempts = 10 |
|
|
|
""" |
|
|
|
name = "ec2" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, ami_image_id="ami-da0cf8b3", num_instances=1, instance_name="employed", |
|
|
|
region="us-east-1", instance_type="t1.micro", key_name=None, security_group="default", |
|
|
|
user_name="root", host_key="~/.ssh/known_hosts", ssh_pwd=None, wait_interval=5 |
|
|
|
self, ami_image_id="ami-da0cf8b3", num_instances=1, |
|
|
|
instance_name="employed", region="us-east-1", |
|
|
|
instance_type="t1.micro", key_name=None, |
|
|
|
security_group="default", user_name="root", |
|
|
|
host_key="~/.ssh/known_hosts", ssh_pwd=None, |
|
|
|
wait_interval=5, connection_attempts=10 |
|
|
|
): |
|
|
|
""" |
|
|
|
Construct for :class:``employ.managers.EC2Manager`` |
|
|
|
|
|
|
|
:param ami_image_id: the ec2 ami image to use |
|
|
|
:type ami_image_id: str |
|
|
|
:param num_instances: the number of ec2 instances to start |
|
|
|
:type num_instances: int |
|
|
|
:param instance_name: the name to assign to each instance |
|
|
|
""" |
|
|
|
self.instances = [] |
|
|
|
self.client_connections = [] |
|
|
|
self.ami_image_id = ami_image_id |
|
|
|
self.num_instances = num_instances |
|
|
|
self.instance_name = instance_name |
|
|
|
@ -48,6 +69,7 @@ class EC2Manager(Manager): |
|
|
|
self.host_key = host_key |
|
|
|
self.ssh_pwd = ssh_pwd |
|
|
|
self.wait_interval = wait_interval |
|
|
|
self.connection_attempts = connection_attempts |
|
|
|
self._connection = None |
|
|
|
|
|
|
|
@classmethod |
|
|
|
@ -64,7 +86,7 @@ class EC2Manager(Manager): |
|
|
|
return [instance.id for instance in self.instances] |
|
|
|
|
|
|
|
def setup_instances(self): |
|
|
|
print "starting instances" |
|
|
|
logger.info("starting %s instances", self.num_instances) |
|
|
|
connection = self.connection() |
|
|
|
reservation = connection.run_instances( |
|
|
|
image_id=self.ami_image_id, |
|
|
|
@ -77,25 +99,100 @@ class EC2Manager(Manager): |
|
|
|
self.instances = reservation.instances |
|
|
|
connection.create_tags(self.instance_ids(), {"Name": self.instance_name}) |
|
|
|
|
|
|
|
print "waiting until they are all running" |
|
|
|
logger.info("waiting until all instances are all 'running'") |
|
|
|
while not all(instance.update() == "running" for instance in self.instances): |
|
|
|
time.sleep(self.wait_interval) |
|
|
|
|
|
|
|
__enter__ = setup_instances |
|
|
|
# connections usually take a bit, might as well wait |
|
|
|
# a little bit before making the first attempt |
|
|
|
time.sleep(self.wait_interval) |
|
|
|
logger.info("establishing ssh connections") |
|
|
|
for instance in self.instances: |
|
|
|
for _ in xrange(self.connection_attempts): |
|
|
|
logger.info( |
|
|
|
"Attempting connection to %s@%s", |
|
|
|
self.user_name, instance.ip_address |
|
|
|
) |
|
|
|
client = paramiko.SSHClient() |
|
|
|
client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) |
|
|
|
try: |
|
|
|
client.connect( |
|
|
|
instance.ip_address, username=self.user_name, |
|
|
|
key_filename=self.host_key |
|
|
|
) |
|
|
|
self.client_connections.append(client) |
|
|
|
break |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
time.sleep(self.wait_interval) |
|
|
|
else: |
|
|
|
raise SSHConnectionError( |
|
|
|
"Could not establish ssh connection to %s@%s after %s attempts", |
|
|
|
self.user_name, instance.ip_address, self.connection_attempts |
|
|
|
) |
|
|
|
|
|
|
|
def cleanup_instances(self): |
|
|
|
for client in self.client_connections: |
|
|
|
client.close() |
|
|
|
|
|
|
|
connection = self.connection() |
|
|
|
connection.terminate_instances(instance_ids=self.instance_ids()) |
|
|
|
|
|
|
|
def __exit__(self, type, value, traceback): |
|
|
|
self.cleanup_instances() |
|
|
|
|
|
|
|
def setup(self, script): |
|
|
|
print "setup script: %s" % script |
|
|
|
remote_file = "/tmp/%s" % basename(script) |
|
|
|
workers = [] |
|
|
|
for client in self.client_connections: |
|
|
|
worker = Thread(target=self._put_file, args=(client, script, remote_file)) |
|
|
|
worker.daemon = True |
|
|
|
worker.start() |
|
|
|
workers.append(worker) |
|
|
|
for worker in workers: |
|
|
|
worker.join() |
|
|
|
|
|
|
|
command = "/bin/sh %s" % remote_file |
|
|
|
results = self._run_multi(command) |
|
|
|
self.validate_results(results, command) |
|
|
|
|
|
|
|
def run(self, command): |
|
|
|
# shell = sshclient_from_instance( |
|
|
|
# self.instances[0], self.host_key, user_name=self.user_name |
|
|
|
# ) |
|
|
|
# command.aggregate(shell.run(command.command())) |
|
|
|
print "running command: %s" % command.command() |
|
|
|
execute = command.command() |
|
|
|
results = self._run_multi(execute) |
|
|
|
self.validate_results(results, execute) |
|
|
|
command.aggregate(results) |
|
|
|
|
|
|
|
def _run_command(self, client, command, results): |
|
|
|
transport = client.get_transport() |
|
|
|
channel = transport.open_session() |
|
|
|
logger.info("executing command %s", command) |
|
|
|
channel.get_pty() |
|
|
|
channel.exec_command(command) |
|
|
|
status = int(channel.recv_exit_status()) |
|
|
|
stdout = "" |
|
|
|
while channel.recv_ready(): |
|
|
|
stdout += channel.recv(1024) |
|
|
|
stderr = "" |
|
|
|
while channel.recv_stderr_ready(): |
|
|
|
stderr += channel.recv_stderr(1024) |
|
|
|
results.put((status, stdout, stderr)) |
|
|
|
|
|
|
|
def _run_multi(self, command): |
|
|
|
results = Queue() |
|
|
|
workers = [] |
|
|
|
for client in self.client_connections: |
|
|
|
worker = Thread(target=self._run_command, args=(client, command, results)) |
|
|
|
worker.daemon = True |
|
|
|
worker.start() |
|
|
|
workers.append(worker) |
|
|
|
|
|
|
|
for worker in workers: |
|
|
|
worker.join() |
|
|
|
|
|
|
|
all_results = [] |
|
|
|
while not results.empty(): |
|
|
|
all_results.append(results.get()) |
|
|
|
return all_results |
|
|
|
|
|
|
|
def _put_file(self, client, script, remote_file): |
|
|
|
fp = open(script, "r") |
|
|
|
transport = client.get_transport() |
|
|
|
sftp_client = paramiko.SFTPClient.from_transport(transport) |
|
|
|
sftp_client.putfo(fp, remote_file) |