Skip to content

Instantly share code, notes, and snippets.

@chapmanb
Last active February 14, 2017 18:22
Show Gist options
  • Save chapmanb/0e9c0e1b65c25aa2f1777884bb28db0a to your computer and use it in GitHub Desktop.
Save chapmanb/0e9c0e1b65c25aa2f1777884bb28db0a to your computer and use it in GitHub Desktop.
diff --git a/src/toil/provisioners/aws/__init__.py b/src/toil/provisioners/aws/__init__.py
index c64e9d6..830e190 100644
--- a/src/toil/provisioners/aws/__init__.py
+++ b/src/toil/provisioners/aws/__init__.py
@@ -42,7 +42,12 @@ def _getCurrentAWSZone(spotBid=None, nodeType=None, ctx=None):
pass
else:
zone = os.environ.get('TOIL_AWS_ZONE', None)
- if spotBid:
+ if not zone and runningOnEC2():
+ try:
+ zone = get_instance_metadata()['placement']['availability-zone']
+ except KeyError:
+ pass
+ if not zone and spotBid:
# if spot bid is present, all the other parameters must be as well
assert bool(spotBid) == bool(nodeType) == bool(ctx)
# if the zone is unset and we are using the spot market, optimize our
@@ -52,11 +57,6 @@ def _getCurrentAWSZone(spotBid=None, nodeType=None, ctx=None):
zone = boto.config.get('Boto', 'ec2_region_name')
if zone is not None:
zone += 'a' # derive an availability zone in the region
- if not zone and runningOnEC2():
- try:
- zone = get_instance_metadata()['placement']['availability-zone']
- except KeyError:
- pass
return zone
@@ -113,10 +113,13 @@ def choose_spot_zone(zones, bid, spot_history):
for zone in zones:
zone_histories = filter(lambda zone_history:
zone_history.availability_zone == zone.name, spot_history)
- price_deviation = std_dev([history.price for history in zone_histories])
- recent_price = zone_histories[0]
+ if zone_histories:
+ price_deviation = std_dev([history.price for history in zone_histories])
+ recent_price = zone_histories[0].price
+ else:
+ price_deviation, recent_price = 0.0, bid
zone_tuple = ZoneTuple(name=zone.name, price_deviation=price_deviation)
- (markets_over_bid, markets_under_bid)[recent_price.price < bid].append(zone_tuple)
+ (markets_over_bid, markets_under_bid)[recent_price < bid].append(zone_tuple)
return min(markets_under_bid or markets_over_bid,
key=attrgetter('price_deviation')).name
@@ -127,7 +130,8 @@ def optimize_spot_bid(ctx, instance_type, spot_bid):
Check whether the bid is sane and makes an effort to place the instance in a sensible zone.
"""
spot_history = _get_spot_history(ctx, instance_type)
- _check_spot_bid(spot_bid, spot_history)
+ if spot_history:
+ _check_spot_bid(spot_bid, spot_history)
zones = ctx.ec2.get_all_zones()
most_stable_zone = choose_spot_zone(zones, spot_bid, spot_history)
logger.info("Placing spot instances in zone %s.", most_stable_zone)
diff --git a/src/toil/provisioners/aws/awsProvisioner.py b/src/toil/provisioners/aws/awsProvisioner.py
index 33a8f52..77debab 100644
--- a/src/toil/provisioners/aws/awsProvisioner.py
+++ b/src/toil/provisioners/aws/awsProvisioner.py
@@ -26,6 +26,7 @@ from six import iteritems
from six.moves import xrange
from bd2k.util import memoize
+import boto.ec2
from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
from boto.exception import BotoServerError, EC2ResponseError
from cgcloud.lib.ec2 import (ec2_instance_types, a_short_time, create_ondemand_instances,
@@ -48,7 +49,7 @@ class AWSProvisioner(AbstractProvisioner):
def __init__(self, config, batchSystem):
super(AWSProvisioner, self).__init__(config, batchSystem)
self.instanceMetaData = get_instance_metadata()
- self.clusterName = self.instanceMetaData['security-groups']
+ self.clusterName = self._getClusterNameFromTags(self.instanceMetaData)
self.ctx = self._buildContext(clusterName=self.clusterName)
self.spotBid = None
assert config.preemptableNodeType or config.nodeType
@@ -63,6 +64,18 @@ class AWSProvisioner(AbstractProvisioner):
self.masterPublicKey = self.setSSH()
self.tags = self._getLeader(self.clusterName).tags
+ def _getClusterNameFromTags(self, md):
+     """
+     Look up the cluster name from the tags of the instance described by md.
+
+     :param md: instance metadata; handed unchanged to _getClusterInstance
+     :return: the value of the instance's "Name" tag, converted with str()
+     """
+     return str(self._getClusterInstance(md).tags["Name"])
+
+ def _getClusterInstance(self, md):
+     """
+     Fetch the EC2 instance object for the instance described by the given metadata.
+
+     The region is taken from the first group of Context.availability_zone_re matched
+     against the current availability zone; the instance is then queried by its
+     "instance-id" over a boto EC2 connection to that region.
+
+     :param md: instance metadata mapping; only the "instance-id" key is read here
+     :return: the first instance of the first entry returned by get_all_instances
+     """
+     regionName = Context.availability_zone_re.match(getCurrentAWSZone()).group(1)
+     ec2 = boto.ec2.connect_to_region(regionName)
+     reservations = ec2.get_all_instances(instance_ids=[md["instance-id"]])
+     return reservations[0].instances[0]
+
def setSSH(self):
if not os.path.exists('/root/.sshSuccess'):
subprocess.check_call(['ssh-keygen', '-f', '/root/.ssh/id_rsa', '-t', 'rsa', '-N', ''])
@@ -102,7 +115,7 @@ class AWSProvisioner(AbstractProvisioner):
logger.info('SSH ready')
kwargs['tty'] = sys.stdin.isatty()
command = args if args else ['bash']
- cls._sshAppliance(leader.ip_address, *command, **kwargs)
+ cls._sshAppliance(leader.public_dns_name, *command, **kwargs)
def _remainingBillingInterval(self, instance):
return awsRemainingBillingInterval(instance)
@@ -111,7 +124,7 @@ class AWSProvisioner(AbstractProvisioner):
@memoize
def _discoverAMI(cls, ctx):
def descriptionMatches(ami):
- return ami.description is not None and 'stable 1068.9.0' in ami.description
+ return ami.description is not None and 'stable 1235.4.0' in ami.description
coreOSAMI = os.environ.get('TOIL_AWS_AMI')
if coreOSAMI is not None:
return coreOSAMI
@@ -193,7 +206,7 @@ class AWSProvisioner(AbstractProvisioner):
@classmethod
def rsyncLeader(cls, clusterName, args, zone=None):
leader = cls._getLeader(clusterName, zone=zone)
- cls._rsyncNode(leader.ip_address, args)
+ cls._rsyncNode(leader.public_dns_name, args)
@classmethod
def _rsyncNode(cls, ip, args, applianceName='toil_leader'):
@@ -250,7 +263,7 @@ class AWSProvisioner(AbstractProvisioner):
def _waitForNode(cls, instance, role):
# returns the node's IP
cls._waitForIP(instance)
- instanceIP = instance.ip_address
+ instanceIP = instance.public_dns_name
cls._waitForSSHPort(instanceIP)
cls._waitForSSHKeys(instanceIP)
# wait here so docker commands can be used reliably afterwards
@@ -281,8 +294,8 @@ class AWSProvisioner(AbstractProvisioner):
while True:
output = cls._sshInstance(ip_address, '/usr/bin/ps', 'aux')
time.sleep(5)
- if 'docker daemon' in output:
- # docker daemon has started
+ if 'docker daemon' in output or 'dockerd' in output:
+ # docker daemon has started (the process may be named `dockerd`)
break
else:
logger.info('... Still waiting...')
@@ -337,13 +350,14 @@ class AWSProvisioner(AbstractProvisioner):
s.close()
@classmethod
- def launchCluster(cls, instanceType, keyName, clusterName, spotBid=None, userTags=None, zone=None):
+ def launchCluster(cls, instanceType, keyName, clusterName, spotBid=None, userTags=None, zone=None,
+ vpcSubnet=None):
if userTags is None:
userTags = {}
ctx = cls._buildContext(clusterName=clusterName, zone=zone)
profileARN = cls._getProfileARN(ctx)
# the security group name is used as the cluster identifier
- cls._createSecurityGroup(ctx, clusterName)
+ sgs = cls._createSecurityGroup(ctx, clusterName, vpcSubnet)
bdm = cls._getBlockDeviceMapping(ec2_instance_types[instanceType])
leaderData = dict(role='leader',
image=applianceSelf(),
@@ -351,10 +365,12 @@ class AWSProvisioner(AbstractProvisioner):
sshKey='AAAAB3NzaC1yc2Enoauthorizedkeyneeded',
args=leaderArgs.format(name=clusterName))
userData = awsUserData.format(**leaderData)
- kwargs = {'key_name': keyName, 'security_groups': [clusterName],
+ kwargs = {'key_name': keyName, 'security_group_ids': [sg.id for sg in sgs],
'instance_type': instanceType,
'user_data': userData, 'block_device_map': bdm,
'instance_profile_arn': profileARN}
+ if vpcSubnet:
+ kwargs["subnet_id"] = vpcSubnet
if not spotBid:
logger.info('Launching non-preemptable leader')
create_ondemand_instances(ctx.ec2, image_id=cls._discoverAMI(ctx),
@@ -393,13 +409,18 @@ class AWSProvisioner(AbstractProvisioner):
logger.info('Deleting security group...')
for attempt in retry(timeout=300, predicate=expectedShutdownErrors):
with attempt:
- try:
- ctx.ec2.delete_security_group(name=clusterName)
- except BotoServerError as e:
- if e.error_code == 'InvalidGroup.NotFound':
- pass
- else:
- raise
+ sg_torm = None
+ for sg in ctx.ec2.get_all_security_groups():
+ if sg.name == clusterName:
+ sg_torm = sg.id
+ if sg_torm:
+ try:
+ ctx.ec2.delete_security_group(group_id=sg_torm)
+ except BotoServerError as e:
+ if e.error_code == 'InvalidGroup.NotFound':
+ pass
+ else:
+ raise
logger.info('... Succesfully deleted security group')
else:
assert len(instances) > len(instancesToTerminate)
@@ -490,12 +511,14 @@ class AWSProvisioner(AbstractProvisioner):
sshKey=self.masterPublicKey,
args=workerArgs.format(ip=self.leaderIP, preemptable=preemptable, keyPath=keyPath))
userData = awsUserData.format(**workerData)
+ sg_ids = [sg.id for sg in self.ctx.ec2.get_all_security_groups() if sg.name == self.clusterName]
kwargs = {'key_name': self.keyName,
- 'security_groups': [self.clusterName],
+ 'security_group_ids': sg_ids,
'instance_type': self.instanceType.name,
'user_data': userData,
'block_device_map': bdm,
'instance_profile_arn': arn}
+ kwargs["subnet_id"] = self._getClusterInstance(self.instanceMetaData).subnet_id
instancesLaunched = []
@@ -581,14 +604,19 @@ class AWSProvisioner(AbstractProvisioner):
return [request for request in requests if request.id in idsToCancel]
@classmethod
- def _createSecurityGroup(cls, ctx, name):
+ def _createSecurityGroup(cls, ctx, name, vpcSubnet=None):
def groupNotFound(e):
retry = (e.status == 400 and 'does not exist in default VPC' in e.body)
return retry
-
+ vpc_id = None
+ if vpcSubnet:
+ vpc_conn = boto.connect_vpc(region=ctx.ec2.region)
+ subnets = vpc_conn.get_all_subnets(subnet_ids=[vpcSubnet])
+ if len(subnets) > 0:
+ vpc_id = subnets[0].vpc_id
# security group create/get. ssh + all ports open within the group
try:
- web = ctx.ec2.create_security_group(name, 'Toil appliance security group')
+ web = ctx.ec2.create_security_group(name, 'Toil appliance security group', vpc_id=vpc_id)
except EC2ResponseError as e:
if e.status == 400 and 'already exists' in e.body:
pass # group exists- nothing to do
@@ -603,6 +631,11 @@ class AWSProvisioner(AbstractProvisioner):
with attempt:
# the following authorizes all port access within the web security group
web.authorize(ip_protocol='tcp', from_port=0, to_port=65535, src_group=web)
+ # Return only the security group(s) of this cluster: match on the group name and,
+ # when a VPC was resolved from vpcSubnet, on that VPC as well. Without the name check
+ # every group in the VPC (or every group visible to the connection when vpc_id is
+ # None) would be returned and attached to the launched leader.
+ return [sg for sg in ctx.ec2.get_all_security_groups()
+         if sg.name == name and (vpc_id is None or sg.vpc_id == vpc_id)]
@classmethod
def _getProfileARN(cls, ctx):
diff --git a/src/toil/utils/toilLaunchCluster.py b/src/toil/utils/toilLaunchCluster.py
index ef7c254..73f5083 100644
--- a/src/toil/utils/toilLaunchCluster.py
+++ b/src/toil/utils/toilLaunchCluster.py
@@ -49,6 +49,9 @@ def main():
" \"Name\": clusterName,"
" \"Owner\": IAM username"
" }. ")
+ # Optional subnet to launch the cluster in; adjacent literals are concatenated, so the
+ # first one must end with a space to keep the two sentences apart in the help output.
+ parser.add_argument("--vpcSubnet",
+                     help="VPC subnet ID to launch cluster in. Uses default subnet if not specified. "
+                          "This subnet needs to have auto assign IPs turned on.")
config = parseBasicOptions(parser)
setLoggingFromOptions(config)
tagsDict = None if config.tags is None else createTagsDict(config.tags)
@@ -70,4 +73,5 @@ def main():
assert False
provisioner.launchCluster(instanceType=config.nodeType, clusterName=config.clusterName,
- keyName=config.keyPairName, spotBid=spotBid, userTags=tagsDict, zone=config.zone)
+ keyName=config.keyPairName, spotBid=spotBid, userTags=tagsDict, zone=config.zone,
+ vpcSubnet=config.vpcSubnet)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment