Last active
October 14, 2025 15:53
-
-
Save filipeandre/f4a3a9bbb5b927e1dd9d8466d1c149ac to your computer and use it in GitHub Desktop.
Re-encrypt an Amazon RDS *instance* by snapshot→copy(with KMS)→restore.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Re-encrypt an Amazon RDS *instance* by snapshot → copy(with KMS) → restore. | |
| Usage: | |
| python rds_rekey_instance.py \ | |
| --db-identifier my-db \ | |
| --target-kms-key-id arn:aws:kms:us-east-1:123456789012:key/abcd-... \ | |
| --region us-east-1 \ | |
| [--source-snapshot-id my-existing-snapshot] \ | |
| [--manual-snapshot-id my-custom-snap] \ | |
| [--encrypted-snapshot-id my-custom-snap-enc] \ | |
| [--new-instance-identifier my-db-rekeyed] \ | |
| [--role-arn arn:aws:iam::123456789012:role/Admin] \ | |
| [--copy-tags] | |
| """ | |
| import argparse | |
| import datetime as dt | |
| import sys | |
| import time | |
| from typing import Dict, List, Optional | |
| import boto3 | |
| from botocore.exceptions import ClientError | |
| # ---------- Utility ---------- | |
def get_boto3_session(region: str, role_arn: Optional[str] = None) -> boto3.Session:
    """Return a boto3 session bound to *region*.

    Without *role_arn* the default credential chain is used. With it, STS
    ``AssumeRole`` is called (session name ``rds-rekey-<unix-seconds>``) and
    the returned temporary credentials back the new session.
    """
    if not role_arn:
        return boto3.Session(region_name=region)

    sts_client = boto3.client("sts", region_name=region)
    assumed = sts_client.assume_role(
        RoleArn=role_arn,
        RoleSessionName=f"rds-rekey-{int(time.time())}",
    )
    temp_creds = assumed["Credentials"]
    return boto3.Session(
        aws_access_key_id=temp_creds["AccessKeyId"],
        aws_secret_access_key=temp_creds["SecretAccessKey"],
        aws_session_token=temp_creds["SessionToken"],
        region_name=region,
    )
def now_suffix() -> str:
    """Return the current UTC time formatted as ``YYYYmmdd-HHMMSS``.

    Used to make snapshot and instance identifiers unique per run. An aware
    UTC datetime is used because ``datetime.utcnow()`` is deprecated since
    Python 3.12; the formatted output is unchanged.
    """
    return dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d-%H%M%S")
def snapshot_exists(rds, snapshot_id: str) -> Optional[Dict]:
    """Look up a DB snapshot by identifier.

    Returns the first entry of ``DBSnapshots`` from ``describe_db_snapshots``,
    or ``None`` when that list is empty/absent. A ``ClientError`` whose code is
    ``DBSnapshotNotFound`` or ``InvalidParameterValue`` is also reported as
    ``None``; any other client error propagates.
    """
    not_found_codes = ("DBSnapshotNotFound", "InvalidParameterValue")
    try:
        found = rds.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id).get("DBSnapshots", [])
    except ClientError as err:
        if err.response.get("Error", {}).get("Code") in not_found_codes:
            return None
        raise
    if not found:
        return None
    return found[0]
def wait_for_snapshot(rds, snapshot_id: str, poll_sec: int = 30, max_attempts: int = 240):
    """Block until DB snapshot *snapshot_id* is ``available``.

    Uses the boto3 ``db_snapshot_available`` waiter, polling every *poll_sec*
    seconds for at most *max_attempts* polls (defaults: 240 x 30 s = 2 h).
    Copying a large snapshot under a new KMS key can take longer than that,
    so callers may raise *max_attempts*. The waiter itself raises if the
    snapshot never becomes available within the budget.
    """
    waiter = rds.get_waiter("db_snapshot_available")
    print(f"[wait] Waiting for snapshot '{snapshot_id}' to become available...")
    waiter.wait(
        DBSnapshotIdentifier=snapshot_id,
        WaiterConfig={"Delay": poll_sec, "MaxAttempts": max_attempts},
    )
    print(f"[ok] Snapshot '{snapshot_id}' is available.")
def wait_for_instance_available(rds, db_instance_id: str, poll_sec: int = 30, max_attempts: int = 240):
    """Block until DB instance *db_instance_id* is ``available``.

    Uses the boto3 ``db_instance_available`` waiter, polling every *poll_sec*
    seconds for at most *max_attempts* polls (defaults: 240 x 30 s = 2 h).
    Restores of large instances can exceed that, so callers may raise
    *max_attempts*. The waiter itself raises if the budget is exhausted.
    """
    waiter = rds.get_waiter("db_instance_available")
    print(f"[wait] Waiting for DB instance '{db_instance_id}' to become available...")
    waiter.wait(
        DBInstanceIdentifier=db_instance_id,
        WaiterConfig={"Delay": poll_sec, "MaxAttempts": max_attempts},
    )
    print(f"[ok] DB instance '{db_instance_id}' is available.")
def describe_instance(rds, db_id: str) -> Dict:
    """Return the first ``DBInstances`` record for instance *db_id*.

    Any error raised by ``describe_db_instances`` (for example an unknown
    identifier) propagates to the caller unchanged.
    """
    instances = rds.describe_db_instances(DBInstanceIdentifier=db_id)["DBInstances"]
    first_instance = instances[0]
    return first_instance
def ensure_not_aurora(db: Dict):
    """Exit when *db* (a ``describe_db_instances`` record) runs an Aurora engine.

    The engine name is compared case-insensitively against the ``aurora``
    prefix. Aurora storage belongs to a cluster, so instance snapshot APIs do
    not apply; the process exits via ``SystemExit`` with guidance instead.
    """
    engine_name = db["Engine"].lower()
    if not engine_name.startswith("aurora"):
        return
    instance_id = db["DBInstanceIdentifier"]
    raise SystemExit(
        f"[error] '{instance_id}' appears to be an Aurora/cluster engine ('{engine_name}'). "
        "Use the *cluster* snapshot APIs (CopyDBClusterSnapshot / RestoreDBClusterFromSnapshot) for Aurora."
    )
def build_network_and_config_args_from_source(db: Dict) -> Dict:
    """Derive ``restore_db_instance_from_db_snapshot`` kwargs from a source instance.

    *db* is one ``DBInstances`` record from ``describe_db_instances``. The
    result mirrors these source settings:

    - instance class, subnet group and VPC security groups;
    - public accessibility and port;
    - storage type and IOPS, and the Multi-AZ flag;
    - the first parameter group, and the first option group whose status is
      ``None``, ``in-sync`` or ``pending-apply``;
    - CloudWatch log exports, auto-minor-version-upgrade and IAM DB auth.

    Settings missing on the source are omitted. For subnet group, security
    groups, port, IOPS, parameter group and log exports, falsy values are also
    omitted. Deletion protection is not handled here.
    """
    restore_kwargs: Dict = {}

    def take_if_present(source_key: str, target_key: str) -> None:
        # Copy verbatim whenever the key exists, including False values.
        if source_key in db:
            restore_kwargs[target_key] = db[source_key]

    # Instance class / sizing
    take_if_present("DBInstanceClass", "DBInstanceClass")

    # Networking: subnet group, security groups, public accessibility
    subnet_group_name = db.get("DBSubnetGroup", {}).get("DBSubnetGroupName")
    if subnet_group_name:
        restore_kwargs["DBSubnetGroupName"] = subnet_group_name

    security_group_ids: List[str] = []
    for group in db.get("VpcSecurityGroups", []):
        group_id = group.get("VpcSecurityGroupId")
        if group_id:
            security_group_ids.append(group_id)
    if security_group_ids:
        restore_kwargs["VpcSecurityGroupIds"] = security_group_ids

    take_if_present("PubliclyAccessible", "PubliclyAccessible")

    # Port is usually implied by the engine, but restore accepts it explicitly.
    endpoint_port = db.get("Endpoint", {}).get("Port")
    if endpoint_port:
        restore_kwargs["Port"] = endpoint_port

    # Storage: the type is copied as-is; IOPS only when non-zero.
    take_if_present("StorageType", "StorageType")
    provisioned_iops = db.get("Iops")
    if provisioned_iops:
        # NOTE(review): for gp3 volumes below the engine's provisioned-IOPS size
        # threshold, RDS may reject an explicit Iops on restore — confirm.
        restore_kwargs["Iops"] = provisioned_iops

    take_if_present("MultiAZ", "MultiAZ")

    # Parameter group: only the first attached group is carried over.
    parameter_groups = db.get("DBParameterGroups")
    if parameter_groups:
        restore_kwargs["DBParameterGroupName"] = parameter_groups[0]["DBParameterGroupName"]

    # Option group: first membership in an acceptable state, if any.
    memberships = db.get("OptionGroupMemberships")
    if memberships:
        usable_groups = [
            membership["OptionGroupName"]
            for membership in memberships
            if membership.get("Status") in (None, "in-sync", "pending-apply")
        ]
        if usable_groups:
            restore_kwargs["OptionGroupName"] = usable_groups[0]

    # Note the name difference: describe reports "Enabled...", restore takes "Enable...".
    log_exports = db.get("EnabledCloudwatchLogsExports")
    if log_exports:
        restore_kwargs["EnableCloudwatchLogsExports"] = log_exports

    take_if_present("AutoMinorVersionUpgrade", "AutoMinorVersionUpgrade")
    take_if_present("IAMDatabaseAuthenticationEnabled", "EnableIAMDatabaseAuthentication")
    return restore_kwargs
def print_sanity(restored: Dict):
    """Print a short post-restore checklist for manual verification.

    *restored* is a ``describe_db_instances`` record. Fields missing from it
    print as ``None``; an empty security-group list or a missing CA certificate
    print as ``-``. A blank line follows the final row.
    """
    first_param_group = (restored.get("DBParameterGroups") or [{}])[0]
    group_ids = [group.get("VpcSecurityGroupId") for group in restored.get("VpcSecurityGroups", [])]
    sg_text = ", ".join(gid for gid in group_ids if gid) or "-"
    engine_text = f"{restored.get('Engine')} {restored.get('EngineVersion')}"
    encryption_text = f"{restored.get('StorageEncrypted')} (KMS={restored.get('KmsKeyId')})"

    checklist = (
        ("InstanceId...............", restored["DBInstanceIdentifier"]),
        ("Engine/Version...........", engine_text),
        ("SubnetGroup..............", restored.get("DBSubnetGroup", {}).get("DBSubnetGroupName")),
        ("VPC SGs..................", sg_text),
        ("Port.....................", restored.get("Endpoint", {}).get("Port")),
        ("Param Group..............", first_param_group.get("DBParameterGroupName")),
        ("CA Cert..................", restored.get("CACertificateIdentifier") or "-"),
        ("Encrypted................", encryption_text),
        ("PubliclyAccessible.......", restored.get("PubliclyAccessible")),
        ("MultiAZ..................", restored.get("MultiAZ")),
    )

    print("\n[sanity]")
    final_index = len(checklist) - 1
    for index, (label, value) in enumerate(checklist):
        trailer = "\n" if index == final_index else ""
        print(f" {label} {value}{trailer}")
| # ---------- Core workflow ---------- | |
def _default_instance_identifier(db_identifier: str, ts: str, max_len: int = 63) -> str:
    """Return ``<db_identifier>-rekeyed-<ts>`` trimmed to a valid RDS instance-ID length.

    RDS DB instance identifiers are limited to 63 characters, so a long source
    identifier is truncated to keep the timestamp suffix intact. Any trailing
    hyphen left by the cut is dropped, because identifiers may not contain two
    consecutive hyphens.
    """
    suffix = f"-rekeyed-{ts}"
    base = db_identifier[: max(max_len - len(suffix), 1)].rstrip("-")
    return f"{base}{suffix}"


def _kms_key_matches(actual: Optional[str], requested: str) -> Optional[bool]:
    """Best-effort check that a snapshot's KMS key is the requested key.

    *actual* is the snapshot's ``KmsKeyId``; *requested* is the CLI value
    (key ARN or bare key ID). Returns ``None`` when *requested* is an alias,
    because an alias cannot be resolved to a key locally.
    """
    if "alias/" in requested:
        return None
    if not actual:
        return False
    return actual == requested or actual.endswith(f"/{requested}")


def main():
    """Re-encrypt an RDS instance: snapshot -> copy with target KMS key -> restore.

    Steps:
      1. Describe the source instance and reject Aurora engines.
      2. Create or reuse a manual snapshot, or validate a caller-supplied one.
      3. Copy it with ``--target-kms-key-id``, or reuse an existing copy.
      4. Restore a new instance from the copy, mirroring source configuration.
      5. Align the CA certificate with the source, then print a summary.

    Fatal problems exit via ``SystemExit("[error] ...")``.
    """
    parser = argparse.ArgumentParser(description="Re-encrypt an RDS DB instance by snapshot → copy(with KMS) → restore.")
    parser.add_argument("--db-identifier", required=True, help="Source DB instance identifier.")
    parser.add_argument("--target-kms-key-id", required=True, help="KMS Key ARN or ID for the *copied* snapshot.")
    parser.add_argument("--region", required=True, help="AWS region, e.g., eu-west-1.")
    parser.add_argument("--new-instance-identifier", help="Identifier for the restored instance.")
    parser.add_argument("--role-arn", help="Optional IAM role to assume.")
    parser.add_argument("--source-snapshot-id", help="Use existing snapshot ID instead of creating a new one.")
    parser.add_argument("--manual-snapshot-id", help="Custom ID for the manual snapshot to create/reuse.")
    parser.add_argument("--encrypted-snapshot-id", help="Custom ID for the encrypted snapshot copy to create/reuse.")
    parser.add_argument("--copy-tags", action="store_true", help="Copy tags to the encrypted snapshot copy.")
    args = parser.parse_args()

    session = get_boto3_session(region=args.region, role_arn=args.role_arn)
    rds = session.client("rds")

    # Get DB details
    print(f"[info] Describing source DB instance '{args.db_identifier}'...")
    try:
        source_db = describe_instance(rds, args.db_identifier)
    except ClientError as e:
        raise SystemExit(f"[error] Could not describe DB instance '{args.db_identifier}': {e}") from e
    ensure_not_aurora(source_db)
    ts = now_suffix()

    # 1) Decide the manual snapshot to use/create
    if args.source_snapshot_id:
        manual_snapshot_id = args.source_snapshot_id
        if not snapshot_exists(rds, manual_snapshot_id):
            raise SystemExit(
                f"[error] Source snapshot '{manual_snapshot_id}' was not found (or the identifier is invalid)."
            )
        print(f"[info] Using existing snapshot '{manual_snapshot_id}'.")
        # A caller-supplied snapshot may still be creating; the copy needs it available.
        wait_for_snapshot(rds, manual_snapshot_id)
    else:
        # Prefer user-provided name, else default pattern
        manual_snapshot_id = args.manual_snapshot_id or f"{args.db_identifier}-manual-{ts}"
        if snapshot_exists(rds, manual_snapshot_id):
            print(f"[info] Reusing pre-existing snapshot '{manual_snapshot_id}'.")
        else:
            print(f"[action] Creating manual snapshot: {manual_snapshot_id}")
            rds.create_db_snapshot(
                DBSnapshotIdentifier=manual_snapshot_id,
                DBInstanceIdentifier=args.db_identifier,
                Tags=[{"Key": "rds-rekey", "Value": ts}],
            )
        wait_for_snapshot(rds, manual_snapshot_id)

    # 2) Copy with new encryption (reuse encrypted copy if already present)
    encrypted_copy_id = args.encrypted_snapshot_id or f"{manual_snapshot_id}-enc"
    existing_copy = snapshot_exists(rds, encrypted_copy_id)
    if existing_copy:
        print(f"[info] Reusing encrypted snapshot copy '{encrypted_copy_id}'.")
        # A reused copy may have been made with a different key; the restore
        # would then silently inherit the wrong key, so flag it loudly.
        existing_key = existing_copy.get("KmsKeyId")
        if _kms_key_matches(existing_key, args.target_kms_key_id) is False:
            print(
                f"[warn] Reused snapshot '{encrypted_copy_id}' is encrypted with KMS key {existing_key!r}, "
                f"not the requested {args.target_kms_key_id!r}. Pass a different --encrypted-snapshot-id to re-copy."
            )
    else:
        print(f"[action] Copying snapshot '{manual_snapshot_id}' with KMS key: {args.target_kms_key_id}")
        try:
            # No SourceRegion: this is a same-region copy. boto3 uses SourceRegion
            # only to build a pre-signed URL for cross-region copies.
            rds.copy_db_snapshot(
                SourceDBSnapshotIdentifier=manual_snapshot_id,
                TargetDBSnapshotIdentifier=encrypted_copy_id,
                CopyTags=args.copy_tags,
                KmsKeyId=args.target_kms_key_id,
            )
        except ClientError as e:
            raise SystemExit(f"[error] Failed to copy snapshot: {e}") from e
    wait_for_snapshot(rds, encrypted_copy_id)

    # 3) Restore the instance
    new_instance_id = args.new_instance_identifier or _default_instance_identifier(args.db_identifier, ts)
    print(f"[action] Restoring new DB instance '{new_instance_id}' from snapshot '{encrypted_copy_id}'")
    restore_args = dict(
        DBInstanceIdentifier=new_instance_id,
        DBSnapshotIdentifier=encrypted_copy_id,
        CopyTagsToSnapshot=True,
        DeletionProtection=source_db.get("DeletionProtection", False),
    )
    restore_args.update(build_network_and_config_args_from_source(source_db))
    try:
        rds.restore_db_instance_from_db_snapshot(**restore_args)
    except ClientError as e:
        if e.response.get("Error", {}).get("Code") == "DBInstanceAlreadyExists":
            print(f"[warn] Instance '{new_instance_id}' already exists; continuing to wait.")
        else:
            raise
    wait_for_instance_available(rds, new_instance_id)

    # 4) One-shot CA change: copy from source, if present and different, and apply immediately
    source_ca = source_db.get("CACertificateIdentifier")
    if source_ca:
        current_ca = describe_instance(rds, new_instance_id).get("CACertificateIdentifier")
        if current_ca == source_ca:
            print(f"[info] '{new_instance_id}' already uses CA '{source_ca}'; no change needed.")
        else:
            print(f"[action] Copying CA from source: '{source_ca}' → '{new_instance_id}' (ApplyImmediately=True)")
            try:
                rds.modify_db_instance(
                    DBInstanceIdentifier=new_instance_id,
                    CACertificateIdentifier=source_ca,
                    ApplyImmediately=True,  # one-shot behavior
                )
                wait_for_instance_available(rds, new_instance_id)
            except ClientError as e:
                # Best effort: the instance stays on its default CA.
                print(f"[warn] Could not apply source CA '{source_ca}' to '{new_instance_id}': {e}")

    # 5) Verify + sanity checklist
    restored = describe_instance(rds, new_instance_id)
    enc = restored.get("StorageEncrypted", False)
    kms = restored.get("KmsKeyId")
    print(f"[verify] Restored instance encryption: StorageEncrypted={enc}, KmsKeyId={kms}")
    if not enc:
        print(f"[warn] Restored instance '{new_instance_id}' is NOT encrypted; the re-key did not take effect.")
    print_sanity(restored)
    print(f"[done] New instance '{new_instance_id}' created from '{encrypted_copy_id}' and uses KMS key: {kms}")
if __name__ == "__main__":
    # Top-level boundary: turn Ctrl-C and any uncaught error into a
    # conventional exit status. SystemExit raised inside main() (argparse
    # errors and the explicit "[error] ..." exits) is a BaseException, so it
    # is not caught here and keeps its own message and code.
    try:
        main()
    except KeyboardInterrupt:
        print("\n[abort] Interrupted by user.", file=sys.stderr)
        sys.exit(130)
    except Exception as e:
        # Diagnostics go to stderr, matching where SystemExit messages are printed.
        print(f"[error] {e}", file=sys.stderr)
        sys.exit(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.