Skip to content

Instantly share code, notes, and snippets.

@filipeandre
Last active October 14, 2025 15:53
Show Gist options
  • Save filipeandre/f4a3a9bbb5b927e1dd9d8466d1c149ac to your computer and use it in GitHub Desktop.
Save filipeandre/f4a3a9bbb5b927e1dd9d8466d1c149ac to your computer and use it in GitHub Desktop.
Re-encrypt an Amazon RDS *instance* by snapshot→copy(with KMS)→restore.
#!/usr/bin/env python3
"""
Re-encrypt an Amazon RDS *instance* by snapshot → copy(with KMS) → restore.
Usage:
python rds_rekey_instance.py \
--db-identifier my-db \
--target-kms-key-id arn:aws:kms:us-east-1:123456789012:key/abcd-... \
--region us-east-1 \
[--source-snapshot-id my-existing-snapshot] \
[--manual-snapshot-id my-custom-snap] \
[--encrypted-snapshot-id my-custom-snap-enc] \
[--new-instance-identifier my-db-rekeyed] \
[--role-arn arn:aws:iam::123456789012:role/Admin] \
[--copy-tags]
"""
import argparse
import datetime as dt
import sys
import time
from typing import Dict, List, Optional
import boto3
from botocore.exceptions import ClientError
# ---------- Utility ----------
def get_boto3_session(region: str, role_arn: Optional[str] = None) -> boto3.Session:
    """Create a boto3 session bound to *region*.

    When *role_arn* is given, the role is assumed through STS and the
    returned session carries the temporary credentials from that call;
    otherwise the session is built with only the region set.
    """
    if not role_arn:
        return boto3.Session(region_name=region)

    sts_client = boto3.client("sts", region_name=region)
    assumed = sts_client.assume_role(
        RoleArn=role_arn,
        # Epoch seconds keep the session name distinct between runs.
        RoleSessionName=f"rds-rekey-{int(time.time())}",
    )
    temp_creds = assumed["Credentials"]
    return boto3.Session(
        aws_access_key_id=temp_creds["AccessKeyId"],
        aws_secret_access_key=temp_creds["SecretAccessKey"],
        aws_session_token=temp_creds["SessionToken"],
        region_name=region,
    )
def now_suffix() -> str:
    """Return the current UTC time formatted as ``YYYYMMDD-HHMMSS``."""
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # produces exactly the same formatted string.
    return dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d-%H%M%S")
def snapshot_exists(rds, snapshot_id: str) -> Optional[Dict]:
    """Look up a DB snapshot by identifier.

    Returns the first ``DBSnapshots`` entry reported by
    ``describe_db_snapshots``, or ``None`` when the list is empty/missing or
    the call fails with error code ``DBSnapshotNotFound`` or
    ``InvalidParameterValue``. Any other ``ClientError`` is re-raised.
    """
    missing_codes = ("DBSnapshotNotFound", "InvalidParameterValue")
    try:
        response = rds.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
    except ClientError as err:
        if err.response.get("Error", {}).get("Code") in missing_codes:
            return None
        raise

    matches = response.get("DBSnapshots", [])
    if not matches:
        return None
    return matches[0]
def wait_for_snapshot(rds, snapshot_id: str, poll_sec: int = 30, max_attempts: int = 240):
    """Block until the ``db_snapshot_available`` waiter succeeds for *snapshot_id*.

    Args:
        rds: RDS client (anything exposing ``get_waiter``).
        snapshot_id: Identifier of the DB snapshot to wait for.
        poll_sec: Seconds between polls, passed as the waiter's ``Delay``.
        max_attempts: Maximum number of polls, passed as ``MaxAttempts``.
            The default of 240 matches the previously hard-coded value.

    Any exception raised by the waiter propagates to the caller.
    """
    waiter = rds.get_waiter("db_snapshot_available")
    print(f"[wait] Waiting for snapshot '{snapshot_id}' to become available...")
    waiter.wait(
        DBSnapshotIdentifier=snapshot_id,
        WaiterConfig={"Delay": poll_sec, "MaxAttempts": max_attempts},
    )
    print(f"[ok] Snapshot '{snapshot_id}' is available.")
def wait_for_instance_available(rds, db_instance_id: str, poll_sec: int = 30, max_attempts: int = 240):
    """Block until the ``db_instance_available`` waiter succeeds for *db_instance_id*.

    Args:
        rds: RDS client (anything exposing ``get_waiter``).
        db_instance_id: Identifier of the DB instance to wait for.
        poll_sec: Seconds between polls, passed as the waiter's ``Delay``.
        max_attempts: Maximum number of polls, passed as ``MaxAttempts``.
            The default of 240 matches the previously hard-coded value.

    Any exception raised by the waiter propagates to the caller.
    """
    waiter = rds.get_waiter("db_instance_available")
    print(f"[wait] Waiting for DB instance '{db_instance_id}' to become available...")
    waiter.wait(
        DBInstanceIdentifier=db_instance_id,
        WaiterConfig={"Delay": poll_sec, "MaxAttempts": max_attempts},
    )
    print(f"[ok] DB instance '{db_instance_id}' is available.")
def describe_instance(rds, db_id: str) -> Dict:
    """Return the first ``DBInstances`` entry that ``describe_db_instances`` reports for *db_id*."""
    instances = rds.describe_db_instances(DBInstanceIdentifier=db_id)["DBInstances"]
    return instances[0]
def ensure_not_aurora(db: Dict):
    """Abort with ``SystemExit`` when the instance's engine name starts with ``aurora``.

    The comparison is case-insensitive; any other engine returns ``None``.
    """
    engine_name = db["Engine"].lower()
    if not engine_name.startswith("aurora"):
        return

    message = (
        f"[error] '{db['DBInstanceIdentifier']}' appears to be an Aurora/cluster engine ('{engine_name}'). "
        "Use the *cluster* snapshot APIs (CopyDBClusterSnapshot / RestoreDBClusterFromSnapshot) for Aurora."
    )
    raise SystemExit(message)
def build_network_and_config_args_from_source(db: Dict) -> Dict:
    """
    Derive keyword arguments for restore_db_instance_from_db_snapshot from the
    source instance description, so the restored instance mirrors the source's
    class, networking, storage, parameter/option groups and related settings.
    Keys absent (or, where noted, empty) on the source are simply left out.
    """
    restore_kwargs: Dict = {}

    def copy_if_present(source_key: str, target_key: Optional[str] = None) -> None:
        # Copies even falsy values (e.g. False) as long as the key exists.
        if source_key in db:
            restore_kwargs[target_key or source_key] = db[source_key]

    # Instance class / sizing
    copy_if_present("DBInstanceClass")

    # Networking
    subnet_group_name = db.get("DBSubnetGroup", {}).get("DBSubnetGroupName")
    if subnet_group_name:
        restore_kwargs["DBSubnetGroupName"] = subnet_group_name

    security_group_ids: List[str] = []
    for group in db.get("VpcSecurityGroups", []):
        if group.get("VpcSecurityGroupId"):
            security_group_ids.append(group["VpcSecurityGroupId"])
    if security_group_ids:
        restore_kwargs["VpcSecurityGroupIds"] = security_group_ids

    copy_if_present("PubliclyAccessible")

    # Port (taken from the endpoint when the source reports one)
    endpoint_port = db.get("Endpoint", {}).get("Port")
    if endpoint_port:
        restore_kwargs["Port"] = endpoint_port

    # Storage class / IOPS (IOPS only when set to a truthy value)
    copy_if_present("StorageType")
    provisioned_iops = db.get("Iops")
    if provisioned_iops:
        restore_kwargs["Iops"] = provisioned_iops

    # Multi-AZ preference
    copy_if_present("MultiAZ")

    # Parameter group: the first one listed on the source
    parameter_groups = db.get("DBParameterGroups")
    if parameter_groups:
        restore_kwargs["DBParameterGroupName"] = parameter_groups[0]["DBParameterGroupName"]

    # Option group: first membership whose status is unset, in-sync or pending-apply
    option_memberships = db.get("OptionGroupMemberships")
    if option_memberships:
        acceptable_statuses = (None, "in-sync", "pending-apply")
        usable_option_groups = [
            membership["OptionGroupName"]
            for membership in option_memberships
            if membership.get("Status") in acceptable_statuses
        ]
        if usable_option_groups:
            restore_kwargs["OptionGroupName"] = usable_option_groups[0]

    # CloudWatch logs exports (if present on source)
    log_exports = db.get("EnabledCloudwatchLogsExports")
    if log_exports:
        restore_kwargs["EnableCloudwatchLogsExports"] = log_exports

    # Auto minor version upgrade and IAM DB auth (the latter under its restore-side name)
    copy_if_present("AutoMinorVersionUpgrade")
    copy_if_present("IAMDatabaseAuthenticationEnabled", "EnableIAMDatabaseAuthentication")

    return restore_kwargs
def print_sanity(restored: Dict):
    """Print a short checklist of the restored instance's settings for manual review."""
    instance_id = restored["DBInstanceIdentifier"]
    engine_label = f"{restored.get('Engine')} {restored.get('EngineVersion')}"
    subnet_group = restored.get("DBSubnetGroup", {}).get("DBSubnetGroupName")
    security_group_ids = [
        group.get("VpcSecurityGroupId") for group in restored.get("VpcSecurityGroups", [])
    ]
    endpoint_port = restored.get("Endpoint", {}).get("Port")
    # Fall back to a single empty entry so a missing/empty list prints None.
    first_param_group = (restored.get("DBParameterGroups") or [{}])[0]
    param_group_name = first_param_group.get("DBParameterGroupName")
    ca_cert = restored.get("CACertificateIdentifier")
    kms_key = restored.get("KmsKeyId")
    encrypted = restored.get("StorageEncrypted")
    is_public = restored.get("PubliclyAccessible")
    is_multi_az = restored.get("MultiAZ")

    report_lines = [
        "\n[sanity]",
        f" InstanceId............... {instance_id}",
        f" Engine/Version........... {engine_label}",
        f" SubnetGroup.............. {subnet_group}",
        f" VPC SGs.................. {', '.join(filter(None, security_group_ids)) or '-'}",
        f" Port..................... {endpoint_port}",
        f" Param Group.............. {param_group_name}",
        f" CA Cert.................. {ca_cert or '-'}",
        f" Encrypted................ {encrypted} (KMS={kms_key})",
        f" PubliclyAccessible....... {is_public}",
        f" MultiAZ.................. {is_multi_az}\n",
    ]
    print("\n".join(report_lines))
# ---------- Core workflow ----------
def main():
    """Command-line entry point: snapshot → copy (with KMS key) → restore.

    Workflow:
      1. Describe the source instance and refuse Aurora engines.
      2. Use ``--source-snapshot-id`` as-is, or create/reuse a manual snapshot.
      3. Create/reuse a copy of that snapshot encrypted with
         ``--target-kms-key-id``.
      4. Restore a new instance from the copy, mirroring the source's
         configuration.
      5. Re-apply the source's CA certificate (best effort).
      6. Print the restored instance's encryption state and a checklist.

    Raises:
        SystemExit: if the source instance cannot be described, is an Aurora
            engine, or the snapshot copy request fails.
    """
    parser = argparse.ArgumentParser(description="Re-encrypt an RDS DB instance by snapshot → copy(with KMS) → restore.")
    parser.add_argument("--db-identifier", required=True, help="Source DB instance identifier.")
    parser.add_argument("--target-kms-key-id", required=True, help="KMS Key ARN or ID for the *copied* snapshot.")
    parser.add_argument("--region", required=True, help="AWS region, e.g., eu-west-1.")
    parser.add_argument("--new-instance-identifier", help="Identifier for the restored instance.")
    parser.add_argument("--role-arn", help="Optional IAM role to assume.")
    parser.add_argument("--source-snapshot-id", help="Use existing snapshot ID instead of creating a new one.")
    parser.add_argument("--manual-snapshot-id", help="Custom ID for the manual snapshot to create/reuse.")
    parser.add_argument("--encrypted-snapshot-id", help="Custom ID for the encrypted snapshot copy to create/reuse.")
    parser.add_argument("--copy-tags", action="store_true", help="Copy tags to the encrypted snapshot copy.")
    args = parser.parse_args()

    session = get_boto3_session(region=args.region, role_arn=args.role_arn)
    rds = session.client("rds")

    # Get DB details
    print(f"[info] Describing source DB instance '{args.db_identifier}'...")
    try:
        source_db = describe_instance(rds, args.db_identifier)
    except ClientError as e:
        raise SystemExit(f"[error] Could not describe DB instance '{args.db_identifier}': {e}")
    ensure_not_aurora(source_db)

    # One timestamp shared by every identifier generated during this run.
    ts = now_suffix()

    # 1) Decide the manual snapshot to use/create
    if args.source_snapshot_id:
        manual_snapshot_id = args.source_snapshot_id
        print(f"[info] Using existing snapshot '{manual_snapshot_id}'.")
    else:
        # Prefer user-provided name, else default pattern
        manual_snapshot_id = args.manual_snapshot_id or f"{args.db_identifier}-manual-{ts}"
        if snapshot_exists(rds, manual_snapshot_id):
            print(f"[info] Reusing pre-existing snapshot '{manual_snapshot_id}'.")
        else:
            print(f"[action] Creating manual snapshot: {manual_snapshot_id}")
            rds.create_db_snapshot(
                DBSnapshotIdentifier=manual_snapshot_id,
                DBInstanceIdentifier=args.db_identifier,
                Tags=[{"Key": "rds-rekey", "Value": ts}],
            )
        # Wait in both cases: a reused snapshot may still be in progress
        # (e.g. left behind by an interrupted earlier run).
        wait_for_snapshot(rds, manual_snapshot_id)

    # 2) Copy with new encryption (reuse encrypted copy if already present)
    encrypted_copy_id = args.encrypted_snapshot_id or f"{manual_snapshot_id}-enc"
    existing_copy = snapshot_exists(rds, encrypted_copy_id)
    if existing_copy:
        print(f"[info] Reusing encrypted snapshot copy '{encrypted_copy_id}'.")
        # The restore inherits the snapshot's key, so flag a copy that does not
        # look like it was made with the requested key. Substring match lets a
        # bare key ID match a full ARN; an alias cannot be confirmed this way.
        existing_key = existing_copy.get("KmsKeyId") or ""
        if args.target_kms_key_id not in existing_key:
            print(
                f"[warn] Snapshot '{encrypted_copy_id}' reports KmsKeyId={existing_key or '-'}, "
                f"which does not appear to match the requested key '{args.target_kms_key_id}'."
            )
    else:
        print(f"[action] Copying snapshot '{manual_snapshot_id}' with KMS key: {args.target_kms_key_id}")
        try:
            # NOTE(review): SourceRegion is the client's own region here
            # (same-region copy) — confirm it is still wanted.
            rds.copy_db_snapshot(
                SourceDBSnapshotIdentifier=manual_snapshot_id,
                TargetDBSnapshotIdentifier=encrypted_copy_id,
                CopyTags=args.copy_tags,
                KmsKeyId=args.target_kms_key_id,
                SourceRegion=args.region,
            )
        except ClientError as e:
            raise SystemExit(f"[error] Failed to copy snapshot: {e}")
    # Wait in both cases: a reused copy may still be in progress.
    wait_for_snapshot(rds, encrypted_copy_id)

    # 3) Restore the instance
    new_instance_id = args.new_instance_identifier or f"{args.db_identifier}-rekeyed-{ts}"
    print(f"[action] Restoring new DB instance '{new_instance_id}' from snapshot '{encrypted_copy_id}'")
    restore_args = dict(
        DBInstanceIdentifier=new_instance_id,
        DBSnapshotIdentifier=encrypted_copy_id,
        CopyTagsToSnapshot=True,
        DeletionProtection=source_db.get("DeletionProtection", False),
    )
    restore_args.update(build_network_and_config_args_from_source(source_db))
    try:
        rds.restore_db_instance_from_db_snapshot(**restore_args)
    except ClientError as e:
        # An already-existing target is tolerated so the script can be re-run.
        if e.response.get("Error", {}).get("Code") == "DBInstanceAlreadyExists":
            print(f"[warn] Instance '{new_instance_id}' already exists; continuing to wait.")
        else:
            raise
    wait_for_instance_available(rds, new_instance_id)

    # 4) One-shot CA change: copy from source, if present, and apply immediately
    source_ca = source_db.get("CACertificateIdentifier")
    if source_ca:
        print(f"[action] Copying CA from source: '{source_ca}' → '{new_instance_id}' (ApplyImmediately=True)")
        try:
            rds.modify_db_instance(
                DBInstanceIdentifier=new_instance_id,
                CACertificateIdentifier=source_ca,
                ApplyImmediately=True,  # one-shot behavior
            )
            wait_for_instance_available(rds, new_instance_id)
        except ClientError as e:
            # Best effort: the instance simply stays on its default CA.
            print(f"[warn] Could not apply source CA '{source_ca}' to '{new_instance_id}': {e}")

    # 5) Verify + sanity checklist
    restored = describe_instance(rds, new_instance_id)
    enc = restored.get("StorageEncrypted", False)
    kms = restored.get("KmsKeyId")
    print(f"[verify] Restored instance encryption: StorageEncrypted={enc}, KmsKeyId={kms}")
    print_sanity(restored)
    print(f"[done] New instance '{new_instance_id}' created from '{encrypted_copy_id}' and uses KMS key: {kms}")
if __name__ == "__main__":
    # Script entry point: run the workflow and map failures to exit codes.
    # Diagnostics go to stderr so stdout carries only the progress output.
    try:
        main()
    except KeyboardInterrupt:
        print("\n[abort] Interrupted by user.", file=sys.stderr)
        sys.exit(130)
    except Exception as e:
        # Top-level boundary: report any unhandled error and exit non-zero.
        print(f"[error] {e}", file=sys.stderr)
        sys.exit(1)
@filipeandre
Copy link
Author

filipeandre commented Oct 14, 2025

# clone or update
REPO_FOLDER=rds_tools
REPO_URL=https://gist.github.com/filipeandre/f4a3a9bbb5b927e1dd9d8466d1c149ac.git
if [ -d .git ]; then
  # Already inside a checkout: update in place. A failed pull no longer
  # falls through to cloning a nested copy.
  git pull --rebase
elif [ -d "$REPO_FOLDER/.git" ]; then
  # Existing clone: update it, then enter it even if the pull failed.
  git -C "$REPO_FOLDER" pull --rebase
  cd "$REPO_FOLDER"
else
  git clone "$REPO_URL" "$REPO_FOLDER" && cd "$REPO_FOLDER"
fi

# execute script
python3 rds_rekey_instance.py \
  --db-identifier my-db \
  --target-kms-key-id arn:aws:kms:eu-west-1:123456789012:key/abcd-... \
  --region eu-west-1 \
  --new-instance-identifier my-db-rekeyed \
  --source-snapshot-id my-existing-snapshot \
  --copy-tags

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment