Created
December 27, 2018 12:23
-
-
Save walterheck/50f58fb411f5b02afdc179e372a657df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import time | |
import datetime | |
import sys | |
import boto3 | |
def caller_arn_to_iam_arn(caller_arn):
    """Turn an STS caller-identity ARN into an IAM principal ARN usable in a key policy.

    An assumed-role identity such as
    ``arn:aws:sts::876521781234:assumed-role/SomeRole/botocore-session-1544534123``
    becomes ``arn:aws:iam::876521781234:role/SomeRole``, whatever the session
    name is. Any ARN that is not an assumed-role ARN (e.g. an IAM user) is
    returned unchanged.

    NOTE(review): an assumed-role ARN does not carry the role's IAM path, so a
    role created under a non-default path is not reconstructed exactly.
    """
    prefix, separator, resource = caller_arn.partition(':assumed-role/')
    if not separator:
        return caller_arn
    # resource is "RoleName/SessionName"; only the role name is part of the IAM ARN
    role_name = resource.split('/', 1)[0]
    return '{}:role/{}'.format(prefix.replace(':sts:', ':iam:', 1), role_name)


def main(db_id, source_profile, target_profile, target_instance_name,
         snapshot_style, target_dbsubnet_group, target_security_group, overwrite_target,
         rename_target):
    """Restore a snapshot of an RDS instance from the source account into the target account.

    Steps:
      1. Get (or create) a customer-managed KMS key in the source account that
         the target principal may use.
      2. Get a snapshot of ``db_id``: a fresh one (``running_instance``) or the
         latest automated backup (``latest_snapshot``).
      3. Copy it re-encrypted with that key, share the copy with the target
         account, and copy it again into the target account.
      4. Delete or rename an existing ``target_instance_name`` if requested,
         restore a new instance from the local copy, wait for it, and attach
         ``target_security_group`` (if given).

    Raises:
        ValueError: if ``snapshot_style`` is not one of the two supported values.
    """
    # validate up front so a bad style fails before any AWS resources are created
    if snapshot_style not in ('running_instance', 'latest_snapshot'):
        raise ValueError('snapshot_style has to be running_instance or latest_snapshot, but value {} found'.format(snapshot_style))
    # set up the source session with the specified profile
    source_session = boto3.session.Session(profile_name=source_profile)
    # retrieve the account number of the source account
    source_account = source_session.client('sts').get_caller_identity().get('Account')
    # start an RDS session for the source
    source_rds_client = source_session.client('rds')
    # set up the target session and its RDS client
    target_session = boto3.session.Session(profile_name=target_profile)
    target_rds_client = target_session.client('rds')
    # a single STS call yields both the target account id and the caller ARN
    target_identity = target_session.client('sts').get_caller_identity()
    target_aws_account = target_identity.get('Account')
    target_account_arn = caller_arn_to_iam_arn(target_identity.get('Arn'))
    # in order to not have dozens of similar keys, stick to this pattern so we can reuse keys
    key_alias = 'alias/RDSBackupRestoreSharedKeyWith{}'.format(target_aws_account)
    kms_key = get_kms_key(source_session, source_account, target_aws_account, target_account_arn, key_alias)
    # Automated Amazon RDS snapshots cannot be shared with other AWS accounts.
    # To share an automated snapshot, copy the snapshot to make a manual version,
    # and then share the copy.
    # Additionally the copy needs to be re-encrypted with the Customer Managed KMS key
    if snapshot_style == 'running_instance':
        snapshot = make_snapshot_from_running_instance(source_rds_client, db_id)
        wait_for_snapshot_to_be_ready(source_rds_client, snapshot)
    else:
        # get the latest automated snapshot of the source instance
        snapshot = get_latest_automatic_rds_snapshots(source_rds_client, db_id)
    recrypted_copy = recrypt_snapshot_with_new_key(source_rds_client, snapshot, kms_key)
    wait_for_snapshot_to_be_ready(source_rds_client, recrypted_copy)
    share_snapshot_with_external_account(source_rds_client, recrypted_copy, target_aws_account)
    # an encrypted shared snapshot owned by another account cannot be restored straight up
    # so make a local copy in the target environment first
    target_copy = copy_shared_snapshot_to_local(target_rds_client, recrypted_copy, kms_key)
    wait_for_snapshot_to_be_ready(target_rds_client, target_copy)
    rename_or_delete_target_instance(target_rds_client, target_instance_name, overwrite_target, rename_target)
    target_instance = create_rds_instance_from_snapshot(rds_client=target_rds_client,
                                                        snapshot=target_copy,
                                                        instancename=target_instance_name,
                                                        dbsubnet_group=target_dbsubnet_group)
    wait_for_instance_to_be_ready(target_rds_client, target_instance)
    modify_rds_instance_security_groups(rds_client=target_rds_client, instancename=target_instance_name, securitygroup=target_security_group)
    print(" Finished, check instance {}!".format(target_instance_name))
def get_kms_key(source_session, source_account, target_aws_account, target_account_arn, key_alias):
    """Return the KMS key that ``key_alias`` points to, creating it when the alias is unknown.

    The key is looked up with ``describe_key`` in the source account. If KMS
    reports ``NotFoundException``, a new shared key is created through
    ``create_shared_kms_key`` and returned instead.

    NOTE(review): a key found via the alias is returned as-is; its policy is
    not inspected, so it is only as shared as it was when first created.

    Returns:
        The ``describe_key`` / ``create_key`` response dict (``['KeyMetadata']['Arn']`` is used by callers).
    """
    kms = source_session.client('kms')
    print("Searching for Customer Managed KMS Key with alias {} that is already shared with account {}...".format(key_alias, target_aws_account))
    try:
        existing_key = kms.describe_key(KeyId=key_alias)
    except kms.exceptions.NotFoundException:
        # no key behind the alias yet: make one
        print(" No valid key found.")
        return create_shared_kms_key(source_session, source_account, target_aws_account, target_account_arn, key_alias)
    print(" Found key: {}".format(existing_key['KeyMetadata']['Arn']))
    return existing_key
def copy_shared_snapshot_to_local(rds_client, shared_snapshot, kms_key):
    """Copy a snapshot shared by another account into the account of ``rds_client``.

    An RDS instance cannot be restored directly from a snapshot shared by
    another account, so a local copy named ``<identifier>-copy`` is made,
    encrypted with ``kms_key``. If that copy already exists it is looked up
    and returned instead.

    Returns:
        The snapshot description dict of the local copy.
    """
    local_snapshot_id = "{}-copy".format(shared_snapshot['DBSnapshotIdentifier'])
    # a snapshot owned by another account has to be addressed by its ARN
    source_arn = shared_snapshot['DBSnapshotArn']
    print("Copying shared snaphot {} to local snapshot {}...".format(source_arn, local_snapshot_id))
    try:
        response = rds_client.copy_db_snapshot(
            SourceDBSnapshotIdentifier=source_arn,
            TargetDBSnapshotIdentifier=local_snapshot_id,
            KmsKeyId=kms_key['KeyMetadata']['Arn'],
        )
        print(" Copy created.")
        return response['DBSnapshot']
    except rds_client.exceptions.DBSnapshotAlreadyExistsFault:
        # reuse the copy left by an earlier run
        print("Snapshot already exists, retrieving {}...".format(local_snapshot_id))
        found = rds_client.describe_db_snapshots(DBSnapshotIdentifier=local_snapshot_id)
        print(" Retrieved.")
        return found['DBSnapshots'][0]
def build_shared_key_policy(source_account, target_account, target_account_arn):
    """Return the JSON key policy for a KMS key shared with ``target_account_arn``.

    - The source account root keeps full control (``kms:*``), as in the KMS default policy.
    - The target principal only gets what is needed to copy/restore snapshots
      encrypted with this key: cryptographic use of the key, plus grant
      management restricted to grants made by AWS services
      (``kms:GrantIsForAWSResource``). It can no longer change the key policy
      or schedule deletion of a key owned by the source account.
    """
    import json

    policy = {
        "Version": "2012-10-17",
        "Id": "key-default-1",
        "Statement": [
            {
                "Sid": "Enable IAM User Permissions",
                "Effect": "Allow",
                "Principal": {"AWS": "arn:aws:iam::{}:root".format(source_account)},
                "Action": "kms:*",
                "Resource": "*",
            },
            {
                "Sid": "Allow use of the key by the {}".format(target_account),
                "Effect": "Allow",
                "Principal": {"AWS": target_account_arn},
                "Action": [
                    "kms:Encrypt",
                    "kms:Decrypt",
                    "kms:ReEncrypt*",
                    "kms:GenerateDataKey*",
                    "kms:DescribeKey",
                ],
                "Resource": "*",
            },
            {
                "Sid": "Allow attachment of persistent resources by the {}".format(target_account),
                "Effect": "Allow",
                "Principal": {"AWS": target_account_arn},
                "Action": ["kms:CreateGrant", "kms:ListGrants", "kms:RevokeGrant"],
                "Resource": "*",
                "Condition": {"Bool": {"kms:GrantIsForAWSResource": True}},
            },
        ],
    }
    # json.dumps instead of string templating guarantees well-formed, escaped JSON
    return json.dumps(policy, indent=2)


def create_shared_kms_key(session, source_account, target_account, target_account_arn, key_alias):
    """Create a customer-managed KMS key usable by another account and give it ``key_alias``.

    A customer-managed key is needed because snapshots encrypted with the
    AWS-managed RDS key cannot be shared. The alias lets later runs find the
    key again without knowing its key id (see ``get_kms_key``).

    NOTE(review): the policy is applied only at creation; keys created by
    earlier versions of this script keep whatever policy they were given.

    Returns:
        The ``create_key`` response dict.
    """
    kms = session.client('kms')
    print("Creating Customer Managed KMS Key that is shared...")
    kms_key = kms.create_key(
        Description="Shared encryption key with AWS account {}".format(target_account_arn),
        Policy=build_shared_key_policy(source_account, target_account, target_account_arn),
    )
    # add an alias to the key so we can later more easily determine if the key
    # already exists without having to know its keyid
    kms.create_alias(
        AliasName=key_alias,
        TargetKeyId=kms_key['KeyMetadata']['Arn']
    )
    print(" Created KMS Key {}, shared with account {}".format(kms_key['KeyMetadata']['Arn'], target_account_arn))
    return kms_key
def share_snapshot_with_external_account(rds_client, snapshot, target_account):
    """Allow ``target_account`` to restore ``snapshot``.

    Adds the account id to the snapshot's ``restore`` attribute; another
    account can only use a snapshot after it has been shared this way.
    """
    snapshot_arn = snapshot['DBSnapshotArn']
    print("Modifying snaphot {} to be shared with account {}...".format(snapshot_arn, target_account))
    attribute_change = {
        'DBSnapshotIdentifier': snapshot['DBSnapshotIdentifier'],
        'AttributeName': 'restore',
        'ValuesToAdd': [target_account],
    }
    rds_client.modify_db_snapshot_attribute(**attribute_change)
    print(" Modified.")
def wait_for_instance_to_be_ready(rds_client, instance, poll_interval=10):
    """Poll until the RDS instance is ``available``.

    Args:
        rds_client: boto3 RDS client of the account that owns the instance.
        instance: the ``restore_db_instance_from_db_snapshot`` response; only
            ``instance['DBInstance']['DBInstanceIdentifier']`` is read.
        poll_interval: seconds to sleep between status checks (default 10).

    Raises:
        RuntimeError: if the instance reaches a status it will not leave on its
            own; previously the loop would have spun forever.
    """
    # statuses from which an instance does not become available without intervention
    failure_states = {
        'failed',
        'restore-error',
        'incompatible-restore',
        'incompatible-parameters',
        'incompatible-network',
        'incompatible-option-group',
        'inaccessible-encryption-credentials',
    }
    instance_id = instance['DBInstance']['DBInstanceIdentifier']
    while True:
        instancecheck = rds_client.describe_db_instances(DBInstanceIdentifier=instance_id)['DBInstances'][0]
        status = instancecheck['DBInstanceStatus']
        if status == 'available':
            print(" Instance {} ready and available!".format(instance_id))
            return
        if status in failure_states:
            raise RuntimeError("Instance {} entered status '{}' and will not become available".format(instance_id, status))
        print("Instance creation in progress, sleeping {} seconds...".format(poll_interval))
        time.sleep(poll_interval)
def wait_for_snapshot_to_be_ready(rds_client, snapshot, poll_interval=10):
    """Poll until the snapshot's status is ``available``, printing its progress.

    Args:
        rds_client: boto3 RDS client of the account that owns the snapshot.
        snapshot: snapshot description dict; only ``DBSnapshotIdentifier`` is read.
        poll_interval: seconds to sleep between status checks (default 10).

    Raises:
        RuntimeError: if the snapshot reports status ``failed``; previously the
            loop would have spun forever.
    """
    snapshot_id = snapshot['DBSnapshotIdentifier']
    while True:
        snapshotcheck = rds_client.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)['DBSnapshots'][0]
        status = snapshotcheck['Status']
        if status == 'available':
            print(" Snapshot {} complete and available!".format(snapshot_id))
            return
        if status == 'failed':
            raise RuntimeError("Snapshot {} failed and will not become available".format(snapshot_id))
        # PercentProgress may be absent early on; report 0 rather than crash
        print("Snapshot {} in progress, {}% complete".format(snapshot_id, snapshotcheck.get('PercentProgress', 0)))
        time.sleep(poll_interval)
def make_snapshot_from_running_instance(rds_client, db_id):
    """Start a manual snapshot of ``db_id`` named ``<db_id>-YYYY-MM-DD`` (today's local date).

    Returns the ``DBSnapshot`` description from ``create_db_snapshot``. Any
    error (including a snapshot with today's name already existing) is
    printed and terminates the script with exit code 1.
    """
    print("Making a new snapshot from the running RDS instance")
    try:
        snapshot_name = "{}-{:%Y-%m-%d}".format(db_id, datetime.date.today())
        response = rds_client.create_db_snapshot(
            DBInstanceIdentifier=db_id,
            DBSnapshotIdentifier=snapshot_name,
        )
        print(" Snapshot created.")
        return response['DBSnapshot']
    except Exception as error:
        print("ERROR: Failed to make snapshot from instance")
        print(error)
        sys.exit(1)
def recrypt_snapshot_with_new_key(rds_client, snapshot, new_kms_key):
    """Copy ``snapshot`` to a manual snapshot encrypted with ``new_kms_key``.

    The copy is named ``<name>-recrypted``. ``<name>`` is the part after the
    first ``:`` when the identifier contains one (presumably the
    ``rds:<name>`` form of automated snapshots); otherwise it is the full
    identifier. If the copy already exists, it is looked up and returned.

    Returns:
        The snapshot description dict of the re-encrypted copy.
    """
    source_id = snapshot['DBSnapshotIdentifier']
    base_name = source_id.split(':')[1] if ':' in source_id else source_id
    target_db_snapshot_id = "{}-recrypted".format(base_name)
    print("Copying automatic snapshot to manual snapshot...")
    try:
        # the new KMS key is the one shared with the target account
        response = rds_client.copy_db_snapshot(
            SourceDBSnapshotIdentifier=source_id,
            TargetDBSnapshotIdentifier=target_db_snapshot_id,
            KmsKeyId=new_kms_key['KeyMetadata']['Arn'],
        )
        print(" Snapshot created.")
        return response['DBSnapshot']
    except rds_client.exceptions.DBSnapshotAlreadyExistsFault:
        # reuse the copy left by an earlier run
        print("Snapshot already exists, retrieving {}".format(target_db_snapshot_id))
        return rds_client.describe_db_snapshots(
            DBSnapshotIdentifier=target_db_snapshot_id,
        )['DBSnapshots'][0]
def rename_or_delete_target_instance(rds_client, instancename, overwrite_target, rename_target,
                                     poll_interval=10, max_attempts=120):
    """Free up the identifier ``instancename`` before a restore, if requested.

    If no instance with that identifier exists, nothing happens. Otherwise:

    - ``overwrite_target`` true: delete it (no final snapshot) and wait for the deletion.
    - else ``rename_target`` true: rename it to ``<instancename>-old`` and wait
      until ``instancename`` is no longer in use. The rename is applied
      asynchronously; restoring under the old name before it completes would
      hit an "already exists" error.
    - neither: leave it untouched (the later restore under the same name will fail).

    Raises:
        RuntimeError: if the rename has not freed the identifier after
            ``max_attempts`` checks ``poll_interval`` seconds apart.
    """
    print("Checking for an existing RDS instance by the name {} and renaming or deleting if it's found".format(instancename))
    # check if we already have an instance by this name
    try:
        rds_client.describe_db_instances(DBInstanceIdentifier=instancename)
        print(" Instance found")
    except rds_client.exceptions.DBInstanceNotFoundFault:
        print(" Instance not found")
        return

    if overwrite_target:
        print(" Instance found and overwrite if found True, deleting instance")
        rds_client.delete_db_instance(
            DBInstanceIdentifier=instancename,
            SkipFinalSnapshot=True
        )
        print(" Deleting instance. This will take a while...")
        waiter = rds_client.get_waiter('db_instance_deleted')
        waiter.wait(
            DBInstanceIdentifier=instancename,
            WaiterConfig={
                'MaxAttempts': 120
            }
        )
        print(" Instance is deleted!")
    elif rename_target:
        print(" Instance found and rename if found True, renaming instance")
        rds_client.modify_db_instance(
            DBInstanceIdentifier=instancename,
            NewDBInstanceIdentifier="{}-old".format(instancename),
            ApplyImmediately=True
        )
        # wait until the old identifier is released so the restore can reuse it
        print(" Waiting for the rename to complete...")
        for _ in range(max_attempts):
            try:
                rds_client.describe_db_instances(DBInstanceIdentifier=instancename)
            except rds_client.exceptions.DBInstanceNotFoundFault:
                print(" Instance renamed to {}-old".format(instancename))
                return
            time.sleep(poll_interval)
        raise RuntimeError("Instance {} was not renamed to {}-old in time".format(instancename, instancename))
    else:
        print(" Instance found but neither overwrite nor rename requested, leaving it in place")
def modify_rds_instance_security_groups(rds_client, instancename, securitygroup):
    """Attach the single VPC security group ``securitygroup`` to ``instancename``.

    The change is applied immediately. Does nothing when ``securitygroup`` is
    None. Errors from ``modify_db_instance`` propagate to the caller.
    """
    if securitygroup is None:
        return
    print("Modifying RDS instance to attach correct securitygroup")
    rds_client.modify_db_instance(
        DBInstanceIdentifier=instancename,
        VpcSecurityGroupIds=[securitygroup],
        ApplyImmediately=True,
    )
    print(" RDS Instance {} modified".format(instancename))
def create_rds_instance_from_snapshot(rds_client, snapshot, instancename, dbsubnet_group):
    """Restore a new RDS instance ``instancename`` from ``snapshot``.

    The snapshot is addressed by its ARN. A ``dbsubnet_group`` of None means
    the subnet group named ``'default'``. If an instance with that name
    already exists, an error is printed and the script exits with code 1.

    Returns:
        The ``restore_db_instance_from_db_snapshot`` response dict.
    """
    print("Restoring RDS instance {} from snapshot {}".format(instancename, snapshot['DBSnapshotIdentifier']))
    subnet_group = 'default' if dbsubnet_group is None else dbsubnet_group
    try:
        restored = rds_client.restore_db_instance_from_db_snapshot(
            DBInstanceIdentifier=instancename,
            DBSnapshotIdentifier=snapshot['DBSnapshotArn'],
            DBSubnetGroupName=subnet_group,
        )
    except rds_client.exceptions.DBInstanceAlreadyExistsFault:
        print("ERROR: an instance with the name {} already exists, please specify a different name or remove that instance first".format(instancename))
        sys.exit(1)
    print(" RDS instance restored.")
    return restored
def get_latest_automatic_rds_snapshots(rds_client, db_id):
    """Return the most recent finished automated snapshot of RDS instance ``db_id``.

    Only snapshots whose ``Status`` is ``available`` are considered: an
    automated backup that is still being taken cannot be copied yet. Among
    those, the one with the newest ``SnapshotCreateTime`` wins; ties go to
    the first one listed.

    Raises:
        ValueError: if the instance has no available automated snapshot.
            Previously this crashed with a TypeError.
    """
    print("Getting latest (automated) snapshot from rds instance {}...".format(db_id))
    # there is no "latest" query, so retrieve the full list and pick from it
    snapshots = rds_client.describe_db_snapshots(
        DBInstanceIdentifier=db_id,
        SnapshotType='automated'
    )['DBSnapshots']
    available = [snap for snap in snapshots if snap.get('Status') == 'available']
    if not available:
        raise ValueError("No available automated snapshot found for RDS instance {}".format(db_id))
    latest = max(available, key=lambda snap: snap['SnapshotCreateTime'])
    print(" Found snapshot {}".format(latest['DBSnapshotIdentifier']))
    return latest
def str2bool(somestr):
    """Parse a yes/no style command-line value into a bool (case-insensitive).

    Accepts yes/true/t/y/1 as True and no/false/f/n/0 as False.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports a usage error.
    """
    normalized = somestr.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description="This is a simple utility to bring up an rds instance based on a snapshot")
    parser.add_argument('--db-instance-identifier', dest='db_id', action='store', required=True, help='The identifier of the rds instance we want to grab the latest snapshot from')
    parser.add_argument('--source-profile', dest='source_profile', action='store', required=True, help='The name of the AWS profile we need to use to look up the rds instance')
    parser.add_argument('--target-profile', dest='target_profile', action='store', required=True, help='The name of the AWS profile we need to use to restore the snapshot to')
    parser.add_argument('--target-instance-name', dest='target_instance', action='store', required=True, help='The name of the rds instance we restore the snapshot to')
    parser.add_argument('--snapshot-style', dest='snapshot_style', action='store', choices=['running_instance', 'latest_snapshot'], required=True,
                        help="'running_instance' makes a fresh snapshot of the running source instance; 'latest_snapshot' uses its most recent automated backup")
    parser.add_argument('--overwrite-target-if-exists', dest='overwrite_target', type=str2bool, choices=[True, False], default=False,
                        help='Whether or not to delete an existing target instance before restoring the backup under its name (takes precedence over renaming)')
    parser.add_argument('--rename-target-if-exists', dest='rename_target', type=str2bool, choices=[True, False], default=False,
                        help='Whether or not to rename an existing target instance to <name>-old before restoring the backup under its name')
    # both of these are optional: the code falls back to the 'default' subnet group
    # and leaves the restored instance's security groups alone when they are omitted
    parser.add_argument('--target-dbsubnet-group', dest='target_dbsubnet_group', action='store', default=None,
                        help="The name of the dbsubnet group we need to use to restore the snapshot to (default: 'default')")
    parser.add_argument('--target-security-group', dest='target_security_group', action='store', default=None,
                        help='The id of the security group to attach to the restored instance (default: leave unchanged)')
    args = parser.parse_args()
    print("supplied arguments:")
    print(args)
    print("")
    main(
        db_id=args.db_id,
        source_profile=args.source_profile,
        target_profile=args.target_profile,
        target_dbsubnet_group=args.target_dbsubnet_group,
        target_security_group=args.target_security_group,
        target_instance_name=args.target_instance,
        snapshot_style=args.snapshot_style,
        overwrite_target=args.overwrite_target,
        rename_target=args.rename_target
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment