Last active
February 14, 2017 18:22
-
-
Save chapmanb/0e9c0e1b65c25aa2f1777884bb28db0a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/toil/provisioners/aws/__init__.py b/src/toil/provisioners/aws/__init__.py | |
index c64e9d6..830e190 100644 | |
--- a/src/toil/provisioners/aws/__init__.py | |
+++ b/src/toil/provisioners/aws/__init__.py | |
@@ -42,7 +42,12 @@ def _getCurrentAWSZone(spotBid=None, nodeType=None, ctx=None): | |
pass | |
else: | |
zone = os.environ.get('TOIL_AWS_ZONE', None) | |
- if spotBid: | |
+ if not zone and runningOnEC2(): | |
+ try: | |
+ zone = get_instance_metadata()['placement']['availability-zone'] | |
+ except KeyError: | |
+ pass | |
+ if not zone and spotBid: | |
# if spot bid is present, all the other parameters must be as well | |
assert bool(spotBid) == bool(nodeType) == bool(ctx) | |
# if the zone is unset and we are using the spot market, optimize our | |
@@ -52,11 +57,6 @@ def _getCurrentAWSZone(spotBid=None, nodeType=None, ctx=None): | |
zone = boto.config.get('Boto', 'ec2_region_name') | |
if zone is not None: | |
zone += 'a' # derive an availability zone in the region | |
- if not zone and runningOnEC2(): | |
- try: | |
- zone = get_instance_metadata()['placement']['availability-zone'] | |
- except KeyError: | |
- pass | |
return zone | |
@@ -113,10 +113,13 @@ def choose_spot_zone(zones, bid, spot_history): | |
for zone in zones: | |
zone_histories = filter(lambda zone_history: | |
zone_history.availability_zone == zone.name, spot_history) | |
- price_deviation = std_dev([history.price for history in zone_histories]) | |
- recent_price = zone_histories[0] | |
+ if zone_histories: | |
+ price_deviation = std_dev([history.price for history in zone_histories]) | |
+ recent_price = zone_histories[0].price | |
+ else: | |
+ price_deviation, recent_price = 0.0, bid | |
zone_tuple = ZoneTuple(name=zone.name, price_deviation=price_deviation) | |
- (markets_over_bid, markets_under_bid)[recent_price.price < bid].append(zone_tuple) | |
+ (markets_over_bid, markets_under_bid)[recent_price < bid].append(zone_tuple) | |
return min(markets_under_bid or markets_over_bid, | |
key=attrgetter('price_deviation')).name | |
@@ -127,7 +130,8 @@ def optimize_spot_bid(ctx, instance_type, spot_bid): | |
Check whether the bid is sane and makes an effort to place the instance in a sensible zone. | |
""" | |
spot_history = _get_spot_history(ctx, instance_type) | |
- _check_spot_bid(spot_bid, spot_history) | |
+ if spot_history: | |
+ _check_spot_bid(spot_bid, spot_history) | |
zones = ctx.ec2.get_all_zones() | |
most_stable_zone = choose_spot_zone(zones, spot_bid, spot_history) | |
logger.info("Placing spot instances in zone %s.", most_stable_zone) | |
diff --git a/src/toil/provisioners/aws/awsProvisioner.py b/src/toil/provisioners/aws/awsProvisioner.py | |
index 33a8f52..77debab 100644 | |
--- a/src/toil/provisioners/aws/awsProvisioner.py | |
+++ b/src/toil/provisioners/aws/awsProvisioner.py | |
@@ -26,6 +26,7 @@ from six import iteritems | |
from six.moves import xrange | |
from bd2k.util import memoize | |
+import boto.ec2 | |
from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType | |
from boto.exception import BotoServerError, EC2ResponseError | |
from cgcloud.lib.ec2 import (ec2_instance_types, a_short_time, create_ondemand_instances, | |
@@ -48,7 +49,7 @@ class AWSProvisioner(AbstractProvisioner): | |
def __init__(self, config, batchSystem): | |
super(AWSProvisioner, self).__init__(config, batchSystem) | |
self.instanceMetaData = get_instance_metadata() | |
- self.clusterName = self.instanceMetaData['security-groups'] | |
+ self.clusterName = self._getClusterNameFromTags(self.instanceMetaData) | |
self.ctx = self._buildContext(clusterName=self.clusterName) | |
self.spotBid = None | |
assert config.preemptableNodeType or config.nodeType | |
@@ -63,6 +64,18 @@ class AWSProvisioner(AbstractProvisioner): | |
self.masterPublicKey = self.setSSH() | |
self.tags = self._getLeader(self.clusterName).tags | |
+ def _getClusterNameFromTags(self, md): | |
+ """Retrieve cluster name from current instance tags | |
+ """ | |
+ instance = self._getClusterInstance(md) | |
+ return str(instance.tags["Name"]) | |
+ | |
+ def _getClusterInstance(self, md): | |
+ zone = getCurrentAWSZone() | |
+ region = Context.availability_zone_re.match(zone).group(1) | |
+ conn = boto.ec2.connect_to_region(region) | |
+ return conn.get_all_instances(instance_ids=[md["instance-id"]])[0].instances[0] | |
+ | |
def setSSH(self): | |
if not os.path.exists('/root/.sshSuccess'): | |
subprocess.check_call(['ssh-keygen', '-f', '/root/.ssh/id_rsa', '-t', 'rsa', '-N', '']) | |
@@ -102,7 +115,7 @@ class AWSProvisioner(AbstractProvisioner): | |
logger.info('SSH ready') | |
kwargs['tty'] = sys.stdin.isatty() | |
command = args if args else ['bash'] | |
- cls._sshAppliance(leader.ip_address, *command, **kwargs) | |
+ cls._sshAppliance(leader.public_dns_name, *command, **kwargs) | |
def _remainingBillingInterval(self, instance): | |
return awsRemainingBillingInterval(instance) | |
@@ -111,7 +124,7 @@ class AWSProvisioner(AbstractProvisioner): | |
@memoize | |
def _discoverAMI(cls, ctx): | |
def descriptionMatches(ami): | |
- return ami.description is not None and 'stable 1068.9.0' in ami.description | |
+ return ami.description is not None and 'stable 1235.4.0' in ami.description | |
coreOSAMI = os.environ.get('TOIL_AWS_AMI') | |
if coreOSAMI is not None: | |
return coreOSAMI | |
@@ -193,7 +206,7 @@ class AWSProvisioner(AbstractProvisioner): | |
@classmethod | |
def rsyncLeader(cls, clusterName, args, zone=None): | |
leader = cls._getLeader(clusterName, zone=zone) | |
- cls._rsyncNode(leader.ip_address, args) | |
+ cls._rsyncNode(leader.public_dns_name, args) | |
@classmethod | |
def _rsyncNode(cls, ip, args, applianceName='toil_leader'): | |
@@ -250,7 +263,7 @@ class AWSProvisioner(AbstractProvisioner): | |
def _waitForNode(cls, instance, role): | |
# returns the node's IP | |
cls._waitForIP(instance) | |
- instanceIP = instance.ip_address | |
+ instanceIP = instance.public_dns_name | |
cls._waitForSSHPort(instanceIP) | |
cls._waitForSSHKeys(instanceIP) | |
# wait here so docker commands can be used reliably afterwards | |
@@ -281,8 +294,8 @@ class AWSProvisioner(AbstractProvisioner): | |
while True: | |
output = cls._sshInstance(ip_address, '/usr/bin/ps', 'aux') | |
time.sleep(5) | |
- if 'docker daemon' in output: | |
- # docker daemon has started | |
+ if 'docker daemon' in output or 'dockerd' in output: | |
+ # docker daemon has started; the process may appear as `dockerd` | |
break | |
else: | |
logger.info('... Still waiting...') | |
@@ -337,13 +350,14 @@ class AWSProvisioner(AbstractProvisioner): | |
s.close() | |
@classmethod | |
- def launchCluster(cls, instanceType, keyName, clusterName, spotBid=None, userTags=None, zone=None): | |
+ def launchCluster(cls, instanceType, keyName, clusterName, spotBid=None, userTags=None, zone=None, | |
+ vpcSubnet=None): | |
if userTags is None: | |
userTags = {} | |
ctx = cls._buildContext(clusterName=clusterName, zone=zone) | |
profileARN = cls._getProfileARN(ctx) | |
# the security group name is used as the cluster identifier | |
- cls._createSecurityGroup(ctx, clusterName) | |
+ sgs = cls._createSecurityGroup(ctx, clusterName, vpcSubnet) | |
bdm = cls._getBlockDeviceMapping(ec2_instance_types[instanceType]) | |
leaderData = dict(role='leader', | |
image=applianceSelf(), | |
@@ -351,10 +365,12 @@ class AWSProvisioner(AbstractProvisioner): | |
sshKey='AAAAB3NzaC1yc2Enoauthorizedkeyneeded', | |
args=leaderArgs.format(name=clusterName)) | |
userData = awsUserData.format(**leaderData) | |
- kwargs = {'key_name': keyName, 'security_groups': [clusterName], | |
+ kwargs = {'key_name': keyName, 'security_group_ids': [sg.id for sg in sgs], | |
'instance_type': instanceType, | |
'user_data': userData, 'block_device_map': bdm, | |
'instance_profile_arn': profileARN} | |
+ if vpcSubnet: | |
+ kwargs["subnet_id"] = vpcSubnet | |
if not spotBid: | |
logger.info('Launching non-preemptable leader') | |
create_ondemand_instances(ctx.ec2, image_id=cls._discoverAMI(ctx), | |
@@ -393,13 +409,18 @@ class AWSProvisioner(AbstractProvisioner): | |
logger.info('Deleting security group...') | |
for attempt in retry(timeout=300, predicate=expectedShutdownErrors): | |
with attempt: | |
- try: | |
- ctx.ec2.delete_security_group(name=clusterName) | |
- except BotoServerError as e: | |
- if e.error_code == 'InvalidGroup.NotFound': | |
- pass | |
- else: | |
- raise | |
+ sg_torm = None | |
+ for sg in ctx.ec2.get_all_security_groups(): | |
+ if sg.name == clusterName: | |
+ sg_torm = sg.id | |
+ if sg_torm: | |
+ try: | |
+ ctx.ec2.delete_security_group(group_id=sg_torm) | |
+ except BotoServerError as e: | |
+ if e.error_code == 'InvalidGroup.NotFound': | |
+ pass | |
+ else: | |
+ raise | |
logger.info('... Succesfully deleted security group') | |
else: | |
assert len(instances) > len(instancesToTerminate) | |
@@ -490,12 +511,14 @@ class AWSProvisioner(AbstractProvisioner): | |
sshKey=self.masterPublicKey, | |
args=workerArgs.format(ip=self.leaderIP, preemptable=preemptable, keyPath=keyPath)) | |
userData = awsUserData.format(**workerData) | |
+ sg_ids = [sg.id for sg in self.ctx.ec2.get_all_security_groups() if sg.name == self.clusterName] | |
kwargs = {'key_name': self.keyName, | |
- 'security_groups': [self.clusterName], | |
+ 'security_group_ids': sg_ids, | |
'instance_type': self.instanceType.name, | |
'user_data': userData, | |
'block_device_map': bdm, | |
'instance_profile_arn': arn} | |
+ kwargs["subnet_id"] = self._getClusterInstance(self.instanceMetaData).subnet_id | |
instancesLaunched = [] | |
@@ -581,14 +604,19 @@ class AWSProvisioner(AbstractProvisioner): | |
return [request for request in requests if request.id in idsToCancel] | |
@classmethod | |
- def _createSecurityGroup(cls, ctx, name): | |
+ def _createSecurityGroup(cls, ctx, name, vpcSubnet=None): | |
def groupNotFound(e): | |
retry = (e.status == 400 and 'does not exist in default VPC' in e.body) | |
return retry | |
- | |
+ vpc_id = None | |
+ if vpcSubnet: | |
+ vpc_conn = boto.connect_vpc(region=ctx.ec2.region) | |
+ subnets = vpc_conn.get_all_subnets(subnet_ids=[vpcSubnet]) | |
+ if len(subnets) > 0: | |
+ vpc_id = subnets[0].vpc_id | |
# security group create/get. ssh + all ports open within the group | |
try: | |
- web = ctx.ec2.create_security_group(name, 'Toil appliance security group') | |
+ web = ctx.ec2.create_security_group(name, 'Toil appliance security group', vpc_id=vpc_id) | |
except EC2ResponseError as e: | |
if e.status == 400 and 'already exists' in e.body: | |
pass # group exists- nothing to do | |
@@ -603,6 +631,11 @@ class AWSProvisioner(AbstractProvisioner): | |
with attempt: | |
# the following authorizes all port access within the web security group | |
web.authorize(ip_protocol='tcp', from_port=0, to_port=65535, src_group=web) | |
+ out_sgs = [] | |
+ for sg in ctx.ec2.get_all_security_groups(): | |
+ if vpc_id is None or sg.vpc_id == vpc_id: | |
+ out_sgs.append(sg) | |
+ return out_sgs | |
@classmethod | |
def _getProfileARN(cls, ctx): | |
diff --git a/src/toil/utils/toilLaunchCluster.py b/src/toil/utils/toilLaunchCluster.py | |
index ef7c254..73f5083 100644 | |
--- a/src/toil/utils/toilLaunchCluster.py | |
+++ b/src/toil/utils/toilLaunchCluster.py | |
@@ -49,6 +49,9 @@ def main(): | |
" \"Name\": clusterName," | |
" \"Owner\": IAM username" | |
" }. ") | |
+ parser.add_argument("--vpcSubnet", | |
+ help="VPC subnet ID to launch cluster in. Uses default subnet if not specified. " | |
+ "This subnet needs to have auto assign IPs turned on.") | |
config = parseBasicOptions(parser) | |
setLoggingFromOptions(config) | |
tagsDict = None if config.tags is None else createTagsDict(config.tags) | |
@@ -70,4 +73,5 @@ def main(): | |
assert False | |
provisioner.launchCluster(instanceType=config.nodeType, clusterName=config.clusterName, | |
- keyName=config.keyPairName, spotBid=spotBid, userTags=tagsDict, zone=config.zone) | |
+ keyName=config.keyPairName, spotBid=spotBid, userTags=tagsDict, zone=config.zone, | |
+ vpcSubnet=config.vpcSubnet) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment