-
-
Save vmoravec/056d7e1e0ad93f51f36bea64686c948a to your computer and use it in GitHub Desktop.
Lambda function for draining ECS instances before terminating them
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import boto3 | |
import base64 | |
import json | |
import logging | |
logging.basicConfig() | |
logger = logging.getLogger() | |
logger.setLevel(logging.DEBUG) | |
# Establish boto3 session | |
session = boto3.session.Session() | |
logger.debug("Session is in region %s ", session.region_name) | |
ec2Client = session.client(service_name='ec2') | |
ecsClient = session.client(service_name='ecs') | |
asgClient = session.client('autoscaling') | |
snsClient = session.client('sns') | |
lambdaClient = session.client('lambda') | |
def publishToSNS(message, topicARN):
    """Re-post a message to an SNS topic so that the lambda is triggered again.

    :param message: JSON-serializable object; it is serialized with ``json.dumps`` and sent as the message body.
    :param topicARN: ARN of the SNS topic the message is published to.
    """
    logger.info("Publish to SNS topic %s", topicARN)
    publishArgs = {
        'TopicArn': topicARN,
        'Message': json.dumps(message),
        'Subject': 'Publishing SNS message to invoke lambda again..',
    }
    snsClient.publish(**publishArgs)
def _clusterNameFromUserData(userDataValue):
    """Extract the ECS cluster name from base64-encoded EC2 user data.

    :param userDataValue: Base64-encoded user data, or None/empty when the instance has none.
    :return: The text following the first '=' in the last whitespace-separated token that
        contains "ECS_CLUSTER", or None when no such token exists.
    """
    if not userDataValue:
        return None
    # b64decode returns bytes on Python 3; decode so the str searches below work on Python 2 and 3 alike.
    userdataDecoded = base64.b64decode(userDataValue).decode('utf-8', 'replace')
    clusterName = None
    for token in userdataDecoded.split():
        # Require '=' as well, so a bare mention of ECS_CLUSTER cannot raise IndexError.
        if "ECS_CLUSTER" in token and '=' in token:
            clusterName = token.split('=')[1]
    return clusterName


def checkContainerInstanceTaskStatus(Ec2InstanceId):
    """Set the ECS container instance backed by an EC2 instance to DRAINING and report whether it has tasks.

    The cluster name is read from the ECS_CLUSTER setting in the instance's user data. The cluster's
    container instances are then scanned for the one whose ec2InstanceId matches. If that container
    instance is not already DRAINING, its state is updated to DRAINING.

    :param Ec2InstanceId: The EC2 instance ID used to identify the cluster and the container instance in it.
    :return: Tuple ``(tasksRunning, tmpMsgAppend)``. ``tasksRunning`` is 1 when ``list_tasks`` returns at
        least one task ARN for the container instance, otherwise 0 (including when no matching container
        instance is found). ``tmpMsgAppend`` is ``{"containerInstanceId": <ARN>}`` when a matching
        container instance is found, otherwise None.
    :raises ValueError: When no ECS_CLUSTER setting can be found in the instance's user data.
    """
    containerInstanceId = None
    tmpMsgAppend = None

    # Describe instance attributes and get the cluster name from the user data, which would have set ECS_CLUSTER
    ec2Resp = ec2Client.describe_instance_attribute(InstanceId=Ec2InstanceId, Attribute='userData')
    logger.debug("Describe instance attributes response %s", ec2Resp)
    # 'Value' is looked up defensively so an instance without user data does not raise KeyError.
    clusterName = _clusterNameFromUserData(ec2Resp.get('UserData', {}).get('Value'))
    logger.info("Cluster name %s", clusterName)
    if clusterName is None:
        # Fail with a clear message instead of passing cluster=None to the ECS API calls below.
        raise ValueError("ECS_CLUSTER not found in user data of instance %s" % Ec2InstanceId)

    # Get list of container instance IDs from the clusterName
    paginator = ecsClient.get_paginator('list_container_instances')
    for containerListResp in paginator.paginate(cluster=clusterName):
        containerInstanceArns = containerListResp['containerInstanceArns']
        if not containerInstanceArns:
            # Nothing to describe on this page; the describe call is expected to reject an empty list.
            continue
        containerDetResp = ecsClient.describe_container_instances(
            cluster=clusterName,
            containerInstances=containerInstanceArns,
        )
        logger.debug("describe container instances response %s", containerDetResp)

        for containerInstance in containerDetResp['containerInstances']:
            logger.debug(
                "Container Instance ARN: %s and ec2 Instance ID %s",
                containerInstance['containerInstanceArn'],
                containerInstance['ec2InstanceId'],
            )
            if containerInstance['ec2InstanceId'] != Ec2InstanceId:
                continue

            containerInstanceId = containerInstance['containerInstanceArn']
            logger.info("Container instance ID of interest : %s", containerInstanceId)

            # Check if the instance state is set to DRAINING. If not, set it, so the ECS cluster
            # handles draining the tasks from the instance.
            if containerInstance['status'] == 'DRAINING':
                logger.info(
                    "Container ID %s with EC2 instance-id %s is draining tasks",
                    containerInstanceId,
                    Ec2InstanceId,
                )
            else:
                logger.info("Make ECS API call to set the container status to DRAINING...")
                ecsClient.update_container_instances_state(
                    cluster=clusterName,
                    containerInstances=[containerInstanceId],
                    status='DRAINING',
                )
            # In both cases the container instance ID is handed back to be appended to the message.
            tmpMsgAppend = {"containerInstanceId": containerInstanceId}
            break

        if containerInstanceId is not None:
            # Match found; no need to fetch further pages.
            break

    if containerInstanceId is None:
        logger.info("NO tasks are on this instance....%s", Ec2InstanceId)
        return 0, tmpMsgAppend

    # Using the container instance ID, list the tasks on that instance.
    listTaskResp = ecsClient.list_tasks(cluster=clusterName, containerInstance=containerInstanceId)
    logger.debug("Container instance task list %s", listTaskResp['taskArns'])

    if listTaskResp['taskArns']:
        logger.info("Tasks are on this instance...%s", Ec2InstanceId)
        return 1, tmpMsgAppend

    logger.info("NO tasks are on this instance...%s", Ec2InstanceId)
    return 0, tmpMsgAppend
def lambda_handler(event, context):
    """Handle an SNS-delivered Auto Scaling message for an instance that is terminating.

    The JSON message is read from ``event['Records'][0]['Sns']['Message']``. Messages without an
    ``EC2InstanceId`` are ignored. For an ``autoscaling:EC2_INSTANCE_TERMINATING`` lifecycle
    transition, the instance's task status is checked via ``checkContainerInstanceTaskStatus``:

    * tasks still present: the message (extended with the container instance ID) is re-published
      to the same SNS topic so this handler runs again;
    * no tasks: the lifecycle action is completed with result ``CONTINUE``.

    :param event: SNS event; ``Records[0]`` must contain ``Sns.Message``, ``Sns.TopicArn`` and
        ``EventSubscriptionArn``.
    :param context: Lambda context object (unused).
    :return: None.
    """
    record = event['Records'][0]
    message = json.loads(record['Sns']['Message'])
    # Payloads that are not a JSON object cannot carry an instance ID; ignore them instead of crashing.
    if not isinstance(message, dict) or not message.get('EC2InstanceId'):
        return

    Ec2InstanceId = message['EC2InstanceId']
    asgGroupName = message['AutoScalingGroupName']
    snsArn = record['EventSubscriptionArn']
    TopicArn = record['Sns']['TopicArn']
    clusterName = None

    logger.info("Lambda received the event %s", event)
    logger.debug("records: %s", record)
    logger.debug("sns: %s", record['Sns'])
    logger.debug("Message: %s", message)
    logger.debug("Ec2 Instance Id %s ,%s", Ec2InstanceId, asgGroupName)
    logger.debug("SNS ARN %s", snsArn)

    # Describe instance attributes and get the cluster name from the user data, which would have set ECS_CLUSTER.
    # The name is only used for the debug log below.
    ec2Resp = ec2Client.describe_instance_attribute(InstanceId=Ec2InstanceId, Attribute='userData')
    logger.debug("Describe instance attributes response %s", ec2Resp)
    # 'Value' is looked up defensively so an instance without user data does not raise KeyError.
    userdataValue = ec2Resp.get('UserData', {}).get('Value')
    if userdataValue:
        # b64decode returns bytes on Python 3; decode so the str searches below work on Python 2 and 3 alike.
        userdataDecoded = base64.b64decode(userdataValue).decode('utf-8', 'replace')
        for token in userdataDecoded.split():
            # Require '=' as well, so a bare mention of ECS_CLUSTER cannot raise IndexError.
            if "ECS_CLUSTER" in token and '=' in token:
                clusterName = token.split('=')[1]
    logger.debug("Cluster name %s", clusterName)

    # Only lifecycle messages are acted upon.
    if 'LifecycleTransition' not in message:
        return
    logger.debug("message autoscaling %s", message['LifecycleTransition'])

    # If the event received is instance terminating...
    if 'autoscaling:EC2_INSTANCE_TERMINATING' not in message['LifecycleTransition']:
        return

    # Get lifecycle hook name
    lifecycleHookName = message['LifecycleHookName']
    logger.debug("Setting lifecycle hook name %s ", lifecycleHookName)

    # Check if there are any tasks running on the instance
    tasksRunning, tmpMsgAppend = checkContainerInstanceTaskStatus(Ec2InstanceId)
    logger.debug("Returned values received: %s ", tasksRunning)
    if tmpMsgAppend is not None:
        message.update(tmpMsgAppend)

    if tasksRunning == 1:
        # Tasks are still running: re-publish the message so this handler is invoked again.
        publishToSNS(message, TopicArn)
    elif tasksRunning == 0:
        logger.debug("Setting lifecycle to complete; No tasks are running on instance, completing lifecycle action....")
        try:
            response = asgClient.complete_lifecycle_action(
                LifecycleHookName=lifecycleHookName,
                AutoScalingGroupName=asgGroupName,
                LifecycleActionResult='CONTINUE',
                InstanceId=Ec2InstanceId)
        except Exception:
            # Still best-effort (the handler does not fail), but the error is logged with its traceback.
            logger.exception("complete_lifecycle_action failed for instance %s", Ec2InstanceId)
        else:
            logger.info("Response received from complete_lifecycle_action %s", response)
            logger.info("Completedlifecycle hook action")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment