Skip to content

Instantly share code, notes, and snippets.

@walterheck
Created December 27, 2018 12:23
Show Gist options
  • Save walterheck/50f58fb411f5b02afdc179e372a657df to your computer and use it in GitHub Desktop.
Save walterheck/50f58fb411f5b02afdc179e372a657df to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import time
import datetime
import sys
import boto3
def _iam_principal_arn_from_caller_arn(caller_arn):
    """Turn an STS caller ARN into the IAM principal ARN it acts as.

    ``arn:aws:sts::123456789012:assumed-role/SomeRole/any-session-name`` becomes
    ``arn:aws:iam::123456789012:role/SomeRole`` regardless of the session name.
    ARNs that are not assumed-role ARNs (e.g. ``arn:aws:iam::...:user/alice``)
    are returned unchanged.
    """
    prefix, separator, resource = caller_arn.partition(':assumed-role/')
    if not separator:
        return caller_arn
    # resource is "<role-name>/<session-name>"; the session name is arbitrary
    role_name = resource.split('/', 1)[0]
    return '{}:role/{}'.format(prefix.replace(':sts::', ':iam::', 1), role_name)


def main(db_id, source_profile, target_profile, target_instance_name,
         snapshot_style, target_dbsubnet_group, target_security_group, overwrite_target,
         rename_target):
    """Copy an RDS snapshot out of one AWS account and restore it in another.

    Steps: obtain a snapshot of ``db_id`` in the source account, re-encrypt it
    with a customer managed KMS key the target account may use, share it,
    copy it into the target account, make room for ``target_instance_name``
    (delete/rename per the flags), restore the instance, and finally attach
    ``target_security_group``.

    Args:
        db_id: identifier of the source RDS instance.
        source_profile: AWS profile name for the source account.
        target_profile: AWS profile name for the target account.
        target_instance_name: identifier of the instance to restore.
        snapshot_style: ``'running_instance'`` (take a fresh snapshot) or
            ``'latest_snapshot'`` (use the newest automated snapshot).
        target_dbsubnet_group: DB subnet group for the restored instance.
        target_security_group: VPC security group id attached after restore.
        overwrite_target: delete an existing target instance first.
        rename_target: rename an existing target instance to ``<name>-old``.

    Raises:
        ValueError: if ``snapshot_style`` is not one of the two values above.
    """
    # validate before touching AWS so a bad value doesn't leave a KMS key behind
    if snapshot_style not in ('running_instance', 'latest_snapshot'):
        raise ValueError('snapshot_style has to be running_instance or latest_snapshot, but value {} found'.format(snapshot_style))

    # source account: session, account id and RDS client
    source_session = boto3.session.Session(profile_name=source_profile)
    source_account = source_session.client('sts').get_caller_identity().get('Account')
    source_rds_client = source_session.client('rds')

    # target account: session, RDS client, and one STS call for id + caller ARN
    target_session = boto3.session.Session(profile_name=target_profile)
    target_rds_client = target_session.client('rds')
    target_identity = target_session.client('sts').get_caller_identity()
    target_aws_account = target_identity.get('Account')
    target_account_arn = _iam_principal_arn_from_caller_arn(target_identity.get('Arn'))

    # create (or reuse) a customer managed KMS key the target account may use;
    # a fixed alias pattern avoids piling up near-identical keys
    key_alias = 'alias/RDSBackupRestoreSharedKeyWith{}'.format(target_aws_account)
    kms_key = get_kms_key(source_session, source_account, target_aws_account, target_account_arn, key_alias)

    # Automated Amazon RDS snapshots cannot be shared with other AWS accounts.
    # To share an automated snapshot, copy the snapshot to make a manual version,
    # and then share the copy.
    # Additionally the copy needs to be re-encrypted with the Customer Managed KMS key
    if snapshot_style == 'running_instance':
        snapshot = make_snapshot_from_running_instance(source_rds_client, db_id)
        wait_for_snapshot_to_be_ready(source_rds_client, snapshot)
    else:
        snapshot = get_latest_automatic_rds_snapshots(source_rds_client, db_id)

    recrypted_copy = recrypt_snapshot_with_new_key(source_rds_client, snapshot, kms_key)
    wait_for_snapshot_to_be_ready(source_rds_client, recrypted_copy)
    share_snapshot_with_external_account(source_rds_client, recrypted_copy, target_aws_account)

    # an encrypted shared snapshot owned by another account cannot be restored
    # straight up, so make a local copy in the target environment first
    target_copy = copy_shared_snapshot_to_local(target_rds_client, recrypted_copy, kms_key)
    wait_for_snapshot_to_be_ready(target_rds_client, target_copy)

    rename_or_delete_target_instance(target_rds_client, target_instance_name, overwrite_target, rename_target)
    target_instance = create_rds_instance_from_snapshot(rds_client=target_rds_client,
                                                        snapshot=target_copy,
                                                        instancename=target_instance_name,
                                                        dbsubnet_group=target_dbsubnet_group)
    wait_for_instance_to_be_ready(target_rds_client, target_instance)
    modify_rds_instance_security_groups(rds_client=target_rds_client, instancename=target_instance_name, securitygroup=target_security_group)
    print(" Finished, check instance {}!".format(target_instance_name))
def get_kms_key(source_session, source_account, target_aws_account, target_account_arn, key_alias):
    """Return the customer managed KMS key behind ``key_alias``, creating it if absent.

    The alias is looked up in the source account. When KMS reports it does
    not exist, a new key and alias are created via ``create_shared_kms_key``.
    The return value is the raw ``describe_key`` / ``create_key`` response;
    callers read ``['KeyMetadata']['Arn']`` from it.

    NOTE(review): an existing key is reused as-is; its policy and state are
    not checked here.
    """
    kms_client = source_session.client('kms')
    print("Searching for Customer Managed KMS Key with alias {} that is already shared with account {}...".format(key_alias, target_aws_account))
    try:
        existing = kms_client.describe_key(KeyId=key_alias)
    except kms_client.exceptions.NotFoundException:
        # no key under this alias yet: create one shared with the target
        print(" No valid key found.")
        return create_shared_kms_key(source_session, source_account, target_aws_account, target_account_arn, key_alias)
    print(" Found key: {}".format(existing['KeyMetadata']['Arn']))
    return existing
def copy_shared_snapshot_to_local(rds_client, shared_snapshot, kms_key):
    """Copy a snapshot shared by another account into the account of ``rds_client``.

    An RDS instance cannot be restored directly from a snapshot shared by
    another account, so a local copy named ``<source identifier>-copy`` is
    made, encrypted with ``kms_key``. If a snapshot with that name already
    exists (e.g. from an earlier run), it is looked up and returned instead.

    Returns the ``DBSnapshot`` dict of the new or pre-existing copy.
    """
    local_id = "{}-copy".format(shared_snapshot['DBSnapshotIdentifier'])
    source_arn = shared_snapshot['DBSnapshotArn']
    print("Copying shared snaphot {} to local snapshot {}...".format(source_arn, local_id))
    try:
        response = rds_client.copy_db_snapshot(
            SourceDBSnapshotIdentifier=source_arn,
            TargetDBSnapshotIdentifier=local_id,
            KmsKeyId=kms_key['KeyMetadata']['Arn'],
        )
    except rds_client.exceptions.DBSnapshotAlreadyExistsFault:
        # reuse the copy made by a previous run
        print("Snapshot already exists, retrieving {}...".format(local_id))
        found = rds_client.describe_db_snapshots(DBSnapshotIdentifier=local_id)
        print(" Retrieved.")
        return found['DBSnapshots'][0]
    print(" Copy created.")
    return response['DBSnapshot']
def create_shared_kms_key(session, source_account, target_account, target_account_arn, key_alias):
    """Create a customer managed KMS key another account may use, plus an alias.

    The key policy:

    * gives the source account's root principal full control, so the key
      stays administrable from the source account, and
    * lets ``target_account_arn`` use the key for cryptographic operations and
      create grants for AWS resources. These are the permissions RDS needs to
      copy/restore an encrypted shared snapshot. The target gets no
      administrative rights (disable, schedule deletion, change policy), so it
      cannot break or take over the source account's key.

    Args:
        session: boto3 session for the source account (the key owner).
        source_account: account id that owns the key.
        target_account: account id the key is shared with (used in the Sids).
        target_account_arn: IAM principal ARN in the target account.
        key_alias: alias (``alias/...``) attached to the new key so later runs
            can find it without knowing the key id.

    Returns:
        The raw ``create_key`` response; ``['KeyMetadata']['Arn']`` is the key ARN.
    """
    import json

    kms = session.client('kms')
    print("Creating Customer Managed KMS Key that is shared...")
    # built as a dict and serialised with json.dumps instead of %-formatting
    # JSON text, so substituted values are always correctly quoted
    policy = {
        "Version": "2012-10-17",
        "Id": "key-default-1",
        "Statement": [
            {
                "Sid": "Enable IAM User Permissions",
                "Effect": "Allow",
                "Principal": {"AWS": "arn:aws:iam::{}:root".format(source_account)},
                "Action": "kms:*",
                "Resource": "*",
            },
            {
                # usage only -- deliberately no kms:* for the external account
                "Sid": "Allow use of the key by the {}".format(target_account),
                "Effect": "Allow",
                "Principal": {"AWS": target_account_arn},
                "Action": [
                    "kms:CreateGrant",
                    "kms:Encrypt",
                    "kms:Decrypt",
                    "kms:ReEncrypt*",
                    "kms:GenerateDataKey*",
                    "kms:DescribeKey",
                ],
                "Resource": "*",
            },
            {
                # lets RDS in the target account manage grants for the
                # snapshots/instances it creates with this key
                "Sid": "Allow attachment of persistent resources by the {}".format(target_account),
                "Effect": "Allow",
                "Principal": {"AWS": target_account_arn},
                "Action": ["kms:CreateGrant", "kms:ListGrants", "kms:RevokeGrant"],
                "Resource": "*",
                "Condition": {"Bool": {"kms:GrantIsForAWSResource": True}},
            },
        ],
    }
    kms_key = kms.create_key(
        Description="Shared encryption key with AWS account {}".format(target_account_arn),
        Policy=json.dumps(policy),
    )
    # add an alias to the key so we can later more easily determine if the key
    # already exists without having to know its key id
    kms.create_alias(
        AliasName=key_alias,
        TargetKeyId=kms_key['KeyMetadata']['Arn'],
    )
    print(" Created KMS Key {}, shared with account {}".format(kms_key['KeyMetadata']['Arn'], target_account_arn))
    return kms_key
def share_snapshot_with_external_account(rds_client, snapshot, target_account):
    """Allow ``target_account`` to restore from ``snapshot``.

    Adds the account id to the snapshot's ``restore`` attribute. A snapshot
    must be shared this way before another account can copy or restore it.
    """
    snapshot_arn = snapshot['DBSnapshotArn']
    print("Modifying snaphot {} to be shared with account {}...".format(snapshot_arn, target_account))
    rds_client.modify_db_snapshot_attribute(
        AttributeName='restore',
        DBSnapshotIdentifier=snapshot['DBSnapshotIdentifier'],
        ValuesToAdd=[target_account],
    )
    print(" Modified.")
# Instance statuses from which a freshly restored instance will not reach
# 'available' on its own (see the RDS "DB instance status" table).
_INSTANCE_FAILURE_STATUSES = frozenset((
    'failed',
    'restore-error',
    'incompatible-restore',
    'incompatible-network',
    'incompatible-parameters',
    'inaccessible-encryption-credentials',
))


def wait_for_instance_to_be_ready(rds_client, instance, poll_interval=10):
    """Block until the RDS instance described by ``instance`` is ``available``.

    Args:
        rds_client: boto3 RDS client for the account that owns the instance.
        instance: a response dict with a ``DBInstance`` entry, such as the
            return value of ``restore_db_instance_from_db_snapshot``.
        poll_interval: seconds to sleep between status checks.

    Raises:
        RuntimeError: if the instance enters a terminal failure status,
            instead of polling forever.
    """
    instance_id = instance['DBInstance']['DBInstanceIdentifier']
    while True:
        status = rds_client.describe_db_instances(DBInstanceIdentifier=instance_id)['DBInstances'][0]['DBInstanceStatus']
        if status == 'available':
            print(" Instance {} ready and available!".format(instance_id))
            return
        if status in _INSTANCE_FAILURE_STATUSES:
            raise RuntimeError("Instance {} entered status '{}' and will not become available".format(instance_id, status))
        print("Instance creation in progress, sleeping {} seconds...".format(poll_interval))
        time.sleep(poll_interval)
# Snapshot statuses from which a snapshot will not become 'available'.
# NOTE(review): the RDS API reference does not enumerate DBSnapshot.Status;
# these are assumed terminal -- confirm and extend as needed.
_SNAPSHOT_FAILURE_STATUSES = frozenset(('failed', 'deleting'))


def wait_for_snapshot_to_be_ready(rds_client, snapshot, poll_interval=10):
    """Block until the snapshot described by ``snapshot`` is ``available``.

    Progress is printed on every check.

    Args:
        rds_client: boto3 RDS client for the account that owns the snapshot.
        snapshot: a ``DBSnapshot`` dict; only ``DBSnapshotIdentifier`` is read.
        poll_interval: seconds to sleep between status checks.

    Raises:
        RuntimeError: if the snapshot enters a status listed in
            ``_SNAPSHOT_FAILURE_STATUSES``, instead of polling forever.
    """
    snapshot_id = snapshot['DBSnapshotIdentifier']
    while True:
        current = rds_client.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)['DBSnapshots'][0]
        status = current['Status']
        if status == 'available':
            print(" Snapshot {} complete and available!".format(snapshot_id))
            return
        if status in _SNAPSHOT_FAILURE_STATUSES:
            raise RuntimeError("Snapshot {} entered status '{}' and will not become available".format(snapshot_id, status))
        # PercentProgress may be absent right after the copy/create request
        print("Snapshot {} in progress, {}% complete".format(snapshot_id, current.get('PercentProgress', 0)))
        time.sleep(poll_interval)
def make_snapshot_from_running_instance(rds_client, db_id):
    """Take a manual snapshot of the instance ``db_id``.

    The snapshot is named ``<db_id>-YYYY-MM-DD`` using today's local date.
    Any exception from this step (for example, a name clash on a second run
    the same day) is printed and the script exits with status 1.

    Returns the ``DBSnapshot`` dict from ``create_db_snapshot``.
    """
    print("Making a new snapshot from the running RDS instance")
    try:
        snapshot_name = "{}-{:%Y-%m-%d}".format(db_id, datetime.date.today())
        response = rds_client.create_db_snapshot(
            DBSnapshotIdentifier=snapshot_name,
            DBInstanceIdentifier=db_id,
        )
        print(" Snapshot created.")
        return response['DBSnapshot']
    except Exception as err:
        print("ERROR: Failed to make snapshot from instance")
        print(err)
        sys.exit(1)
def recrypt_snapshot_with_new_key(rds_client, snapshot, new_kms_key):
    """Copy ``snapshot`` to a manual snapshot encrypted with ``new_kms_key``.

    The copy is named ``<base>-recrypted``. ``<base>`` is the source
    identifier, or, if the identifier contains ``:``, the segment after the
    first ``:`` (presumably to strip the ``rds:`` prefix of automated
    snapshots). If a snapshot with that name already exists, it is looked
    up and returned instead.

    Returns the ``DBSnapshot`` dict of the new or pre-existing copy.
    """
    source_id = snapshot['DBSnapshotIdentifier']
    base_id = source_id.split(':')[1] if ':' in source_id else source_id
    target_db_snapshot_id = "{}-recrypted".format(base_id)
    print("Copying automatic snapshot to manual snapshot...")
    try:
        # the new key is the one shared with the target account
        response = rds_client.copy_db_snapshot(
            SourceDBSnapshotIdentifier=source_id,
            TargetDBSnapshotIdentifier=target_db_snapshot_id,
            KmsKeyId=new_kms_key['KeyMetadata']['Arn'],
        )
    except rds_client.exceptions.DBSnapshotAlreadyExistsFault:
        print("Snapshot already exists, retrieving {}".format(target_db_snapshot_id))
        existing = rds_client.describe_db_snapshots(DBSnapshotIdentifier=target_db_snapshot_id)
        return existing['DBSnapshots'][0]
    print(" Snapshot created.")
    return response['DBSnapshot']
def rename_or_delete_target_instance(rds_client, instancename, overwrite_target, rename_target, poll_interval=10):
    """Free up the identifier ``instancename`` before a restore.

    If an instance with that identifier exists:

    * ``overwrite_target`` true: delete it without a final snapshot and wait
      with boto3's ``db_instance_deleted`` waiter (up to 120 attempts).
    * otherwise, ``rename_target`` true: rename it to ``<instancename>-old``
      and poll until the old identifier no longer resolves.
    * otherwise: leave it alone. A restore under the same name will then fail.

    Args:
        rds_client: boto3 RDS client for the target account.
        instancename: identifier that must be free afterwards.
        overwrite_target: delete an existing instance (takes precedence).
        rename_target: rename an existing instance.
        poll_interval: seconds between checks while waiting for a rename.

    Raises:
        RuntimeError: if the old identifier is still in use after 120 checks.
    """
    print("Checking for an existing RDS instance by the name {} and renaming or deleting if it's found".format(instancename))
    try:
        rds_client.describe_db_instances(DBInstanceIdentifier=instancename)
    except rds_client.exceptions.DBInstanceNotFoundFault:
        print(" Instance not found")
        return
    print(" Instance found")

    if overwrite_target:
        print(" Instance found and overwrite if found True, deleting instance")
        rds_client.delete_db_instance(
            DBInstanceIdentifier=instancename,
            SkipFinalSnapshot=True,
        )
        print(" Deleting instance. This will take a while...")
        waiter = rds_client.get_waiter('db_instance_deleted')
        waiter.wait(
            DBInstanceIdentifier=instancename,
            WaiterConfig={'MaxAttempts': 120},
        )
        print(" Instance is deleted!")
    elif rename_target:
        print(" Instance found and rename if found True, renaming instance")
        rds_client.modify_db_instance(
            DBInstanceIdentifier=instancename,
            NewDBInstanceIdentifier="{}-old".format(instancename),
            ApplyImmediately=True,
        )
        # The rename is asynchronous: until it completes the instance still
        # answers to the old identifier, and restoring under that name right
        # away would fail with DBInstanceAlreadyExists.
        for _ in range(120):
            try:
                rds_client.describe_db_instances(DBInstanceIdentifier=instancename)
            except rds_client.exceptions.DBInstanceNotFoundFault:
                break
            print(" Rename in progress, sleeping {} seconds...".format(poll_interval))
            time.sleep(poll_interval)
        else:
            raise RuntimeError("Instance {} was not renamed to {}-old in time".format(instancename, instancename))
        print(" Instance renamed to {}-old".format(instancename))
def modify_rds_instance_security_groups(rds_client, instancename, securitygroup):
    """Set the VPC security group list of ``instancename`` to ``[securitygroup]``.

    Does nothing when ``securitygroup`` is None. The change is requested with
    ``ApplyImmediately=True``; this function does not wait for it to finish.
    Client errors propagate to the caller unchanged.
    """
    # the former try/except only re-raised, so it has been dropped
    if securitygroup is None:
        return
    print("Modifying RDS instance to attach correct securitygroup")
    rds_client.modify_db_instance(
        DBInstanceIdentifier=instancename,
        VpcSecurityGroupIds=[securitygroup],
        ApplyImmediately=True,
    )
    print(" RDS Instance {} modified".format(instancename))
def create_rds_instance_from_snapshot(rds_client, snapshot, instancename, dbsubnet_group):
    """Restore a new RDS instance called ``instancename`` from ``snapshot``.

    A ``dbsubnet_group`` of None falls back to the subnet group named
    ``'default'``. If an instance with that name already exists, an error is
    printed and the script exits with status 1.

    Returns the raw ``restore_db_instance_from_db_snapshot`` response (a dict
    with a ``DBInstance`` entry).
    """
    print("Restoring RDS instance {} from snapshot {}".format(instancename, snapshot['DBSnapshotIdentifier']))
    try:
        subnet_group = 'default' if dbsubnet_group is None else dbsubnet_group
        response = rds_client.restore_db_instance_from_db_snapshot(
            DBSubnetGroupName=subnet_group,
            DBInstanceIdentifier=instancename,
            DBSnapshotIdentifier=snapshot['DBSnapshotArn'],
        )
    except rds_client.exceptions.DBInstanceAlreadyExistsFault:
        print("ERROR: an instance with the name {} already exists, please specify a different name or remove that instance first".format(instancename))
        sys.exit(1)
    print(" RDS instance restored.")
    return response
def get_latest_automatic_rds_snapshots(rds_client, db_id):
    """Return the newest *available* automated snapshot of instance ``db_id``.

    Every page of ``describe_db_snapshots`` is read, not just the first.
    Snapshots that are not ``available`` are skipped, because they cannot be
    copied yet (e.g. an automated backup still in progress). Among the rest,
    the one with the latest ``SnapshotCreateTime`` wins; ties go to the first
    one returned.

    If the instance has no available automated snapshot, an error is printed
    and the script exits with status 1. This replaces the former crash with
    a TypeError on an empty list.
    """
    print("Getting latest (automated) snapshot from rds instance {}...".format(db_id))
    paginator = rds_client.get_paginator('describe_db_snapshots')
    candidates = [
        snapshot
        for page in paginator.paginate(DBInstanceIdentifier=db_id, SnapshotType='automated')
        for snapshot in page['DBSnapshots']
        if snapshot.get('Status') == 'available'
    ]
    if not candidates:
        print("ERROR: no available automated snapshot found for rds instance {}".format(db_id))
        sys.exit(1)
    latest = max(candidates, key=lambda snapshot: snapshot['SnapshotCreateTime'])
    print(" Found snapshot {}".format(latest['DBSnapshotIdentifier']))
    return latest
def str2bool(somestr):
    """Parse a yes/no style command-line value into a bool (for argparse ``type=``).

    Accepted values (case-insensitive): yes/true/t/y/1 map to True and
    no/false/f/n/0 map to False.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports
            it as an invalid argument.
    """
    answers = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    answers.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))
    parsed = answers.get(somestr.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and hand them to main().
    parser = argparse.ArgumentParser(description="This is a simple utility to bring up an rds instance based on a snapshot")
    parser.add_argument('--db-instance-identifier', dest='db_id', action='store', required=True,
                        help='The identifier of the source rds instance to snapshot (running_instance) '
                             'or to take the latest automated snapshot from (latest_snapshot)')
    parser.add_argument('--source-profile', dest='source_profile', action='store', required=True,
                        help='The name of the AWS profile we need to use to look up the rds instance')
    parser.add_argument('--target-profile', dest='target_profile', action='store', required=True,
                        help='The name of the AWS profile we need to use to restore the snapshot to')
    parser.add_argument('--target-instance-name', dest='target_instance', action='store', required=True,
                        help='The name of the rds instance we restore the snapshot to')
    parser.add_argument('--snapshot-style', dest='snapshot_style', action='store',
                        choices=['running_instance', 'latest_snapshot'], required=True,
                        help='running_instance: make a fresh snapshot of the running source instance; '
                             'latest_snapshot: use the most recent automated backup snapshot')
    parser.add_argument('--overwrite-target-if-exists', dest='overwrite_target', type=str2bool,
                        choices=[True, False], default=False,
                        help='Whether or not to delete the target instance and replace it with the backup if it exists '
                             '(takes precedence over --rename-target-if-exists)')
    parser.add_argument('--rename-target-if-exists', dest='rename_target', type=str2bool,
                        choices=[True, False], default=False,
                        help='Whether or not to rename an existing target instance to <name>-old '
                             'before restoring the backup under the original name')
    parser.add_argument('--target-dbsubnet-group', dest='target_dbsubnet_group', action='store', required=True,
                        help='The name of the dbsubnet group we need to use to restore the snapshot to')
    parser.add_argument('--target-security-group', dest='target_security_group', action='store', required=True,
                        help='The id of the security group we need to use to restore the snapshot to')
    args = parser.parse_args()
    print("supplied arguments:")
    print(args)
    print("")
    main(
        db_id=args.db_id,
        source_profile=args.source_profile,
        target_profile=args.target_profile,
        target_dbsubnet_group=args.target_dbsubnet_group,
        target_security_group=args.target_security_group,
        target_instance_name=args.target_instance,
        snapshot_style=args.snapshot_style,
        overwrite_target=args.overwrite_target,
        rename_target=args.rename_target,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment