Last active
October 14, 2025 15:40
-
-
Save filipeandre/90245f8634a3aeb523296d11804ae8b0 to your computer and use it in GitHub Desktop.
This script acts as a lockdown switch for all regional AWS WAF web ACLs in the current region.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Add a top-priority "block-all" rule to every AWS WAFv2 Web ACL in the current Region (REGIONAL scope), | |
| without removing existing rules. The change is fully reversible via a local backup file. | |
| Features | |
| - Enumerates all REGIONAL Web ACLs in the configured AWS region | |
| - Creates (or reuses) IP sets that match all IPv4 and IPv6 addresses (0.0.0.0/0 and ::/0) | |
| - Inserts a new rule at priority 0 that blocks all traffic, shifting existing rule priorities down | |
| - Stores a per-WebACL backup (original rules + metadata) under ./waf_backups/<region>/<web_acl_id>.json | |
| - Revert command removes the block-all rule and restores the original rules exactly as saved in the backup | |
| - Dry run mode for both apply and revert | |
| - Optional role assumption before any API calls | |
| Usage | |
| Apply (add the rule): | |
| python aws-waf-add-block-all-rule.py apply --region eu-west-1 | |
| Revert (restore from backup): | |
| python aws-waf-add-block-all-rule.py revert --region eu-west-1 | |
| With role assumption: | |
| python aws-waf-add-block-all-rule.py apply --region eu-west-1 --role-arn arn:aws:iam::123456789012:role/Admin | |
| Dry run: | |
| python aws-waf-add-block-all-rule.py apply --region eu-west-1 --dry-run | |
| python aws-waf-add-block-all-rule.py revert --region eu-west-1 --dry-run | |
| Notes | |
| - Scope is REGIONAL only. If you need CLOUDFRONT scope, extend the script accordingly. | |
| - Backups are local files; keep them safe to ensure reversibility. | |
| """ | |
import argparse
import copy
import json
import os
import sys
import tempfile
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import boto3
from botocore.exceptions import ClientError
# WAFv2 scope handled by every API call in this script.
SCOPE: str = "REGIONAL"  # This script targets the current region (REGIONAL) only

# Name (also used as the CloudWatch metric name) of the temporary rule inserted at priority 0.
BLOCK_ALL_RULE_NAME: str = "__TEMP_BLOCK_ALL__"

# Names of the temporary IP sets the block-all rule references (one per address family).
IPSET_V4_NAME: str = "__TEMP_ALL_IPV4__"
IPSET_V6_NAME: str = "__TEMP_ALL_IPV6__"

# Root folder for per-Web-ACL JSON backups: <BACKUP_DIR>/<region>/<web_acl_id>.json
BACKUP_DIR: str = "waf_backups"
@dataclass
class AwsClients:
    """Bundle of the boto3 clients shared by the apply/revert commands.

    Fields are typed ``Any`` (not the builtin function ``any``) because boto3
    clients are generated at runtime and have no importable static type.
    """

    waf: Any  # boto3 "wafv2" client used for every WAF call
    waf_paginator: Any  # object whose paginate(Scope=...) yields list_web_acls response pages
    sts: Any  # boto3 "sts" client; not used by the commands in this script at present
def assume_role_if_needed(role_arn: Optional[str], session: Optional[boto3.Session] = None) -> boto3.Session:
    """Return the session to use for all AWS calls, switching role first if asked.

    With no ``role_arn`` the given session (or a fresh default ``boto3.Session``)
    is returned unchanged. Otherwise STS ``AssumeRole`` is called from that session
    and a new session built from the temporary credentials is returned, carrying
    over the source session's region.
    """
    source = session or boto3.Session()
    if role_arn:
        credentials = source.client("sts").assume_role(
            RoleArn=role_arn,
            RoleSessionName="waf-blockall-script",
        )["Credentials"]
        return boto3.Session(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            region_name=source.region_name,
        )
    return source
class NextMarkerPaginator:
    """Minimal stand-in for a botocore paginator over a NextMarker-paged WAFv2 List* call.

    ``paginate(**kwargs)`` yields raw response pages, re-issuing the call with the
    returned ``NextMarker`` until the service stops returning one.
    """

    def __init__(self, list_call, result_key: str):
        self._list_call = list_call
        self._result_key = result_key

    def paginate(self, **kwargs):
        params = dict(kwargs)
        marker = None
        while True:
            page = self._list_call(**params)
            yield page
            next_marker = page.get("NextMarker")
            # Stop on a missing marker, an empty page, or a repeated marker so a
            # service quirk can never turn into an endless loop.
            if not next_marker or not page.get(self._result_key) or next_marker == marker:
                return
            marker = next_marker
            params["NextMarker"] = marker


def make_clients(region: str, role_arn: Optional[str]) -> "AwsClients":
    """Build the wafv2/sts clients for *region*, assuming *role_arn* first if given.

    ``waf_paginator`` is a real botocore paginator when botocore defines one for
    ``list_web_acls``. Otherwise it is a ``NextMarkerPaginator``: calling
    ``get_paginator`` for an operation without a paginator model raises
    ``OperationNotPageableError``, which would abort the script before any work.
    Botocore releases reviewed do not appear to model wafv2 paginators, but
    ``can_paginate()`` makes this correct either way.
    """
    base = boto3.Session(region_name=region)
    sess = assume_role_if_needed(role_arn, base)
    waf = sess.client("wafv2", region_name=region)
    sts = sess.client("sts", region_name=region)
    if waf.can_paginate("list_web_acls"):
        paginator = waf.get_paginator("list_web_acls")
    else:
        paginator = NextMarkerPaginator(waf.list_web_acls, "WebACLs")
    return AwsClients(waf=waf, waf_paginator=paginator, sts=sts)
def ensure_ipset(clients: AwsClients, name: str, description: str, addresses: List[str], region: str, scope: str = SCOPE) -> Tuple[str, str]:
    """Create or reuse the WAFv2 IP set called *name* so it holds exactly *addresses*.

    The existing IP sets are paged through ``list_ip_sets``. If one with this name
    exists, it is reused, and overwritten with *addresses* (and *description*) when
    its address list differs. Otherwise a new IP set is created. The IP version is
    ``IPV6`` when any address contains ``":"`` and ``IPV4`` otherwise, including
    for an empty list.

    Args:
        clients: AwsClients whose ``waf`` is a wafv2 client.
        name: IP set name to find or create.
        description: Sent on create and on update. UpdateIPSet replaces the whole
            definition, so omitting it there would wipe the description.
        addresses: CIDR strings the IP set must contain.
        region: Unused; kept for call-site compatibility.
        scope: WAFv2 scope, ``SCOPE`` by default.

    Returns:
        ``(Id, ARN)`` of the IP set.

    Raises:
        botocore.exceptions.ClientError: propagated from any WAFv2 call.
    """
    waf = clients.waf
    wanted = sorted(addresses)

    next_marker = None
    while True:
        kwargs = {"Scope": scope}
        if next_marker:
            kwargs["NextMarker"] = next_marker
        page = waf.list_ip_sets(**kwargs)
        summaries = page.get("IPSets", [])
        for ipset_summary in summaries:
            if ipset_summary.get("Name") != name:
                continue
            # Fetch details to compare addresses and obtain the lock token.
            got = waf.get_ip_set(Scope=scope, Name=name, Id=ipset_summary["Id"])
            if sorted(got["IPSet"].get("Addresses", [])) != wanted:
                waf.update_ip_set(
                    Scope=scope,
                    Name=name,
                    Id=ipset_summary["Id"],
                    Description=description,
                    Addresses=addresses,
                    LockToken=got["LockToken"],
                )
            return ipset_summary["Id"], ipset_summary["ARN"]
        new_marker = page.get("NextMarker")
        # Stop on no marker, an empty page, or a repeated marker to avoid looping forever.
        if not new_marker or not summaries or new_marker == next_marker:
            break
        next_marker = new_marker

    ip_version = "IPV6" if any(":" in address for address in addresses) else "IPV4"
    resp = waf.create_ip_set(
        Name=name,
        Scope=scope,
        Description=description,
        IPAddressVersion=ip_version,
        Addresses=addresses,
    )
    return resp["Summary"]["Id"], resp["Summary"]["ARN"]
def list_all_web_acls(clients: AwsClients, scope: str = SCOPE) -> List[Dict]:
    """Return every Web ACL summary in *scope*, flattened across all result pages."""
    pages = clients.waf_paginator.paginate(Scope=scope)
    return [summary for page in pages for summary in page.get("WebACLs", [])]
def load_backup_path(region: str, web_acl_id: str) -> str:
    """Return ``<BACKUP_DIR>/<region>/<web_acl_id>.json`` for one Web ACL.

    Despite the name, this is used for both reading and writing. It always
    creates the file's parent directory as a side effect.
    """
    backup_file = os.path.join(BACKUP_DIR, region, f"{web_acl_id}.json")
    parent = os.path.dirname(backup_file)
    os.makedirs(parent, exist_ok=True)
    return backup_file
def backup_json_default(value):
    """``json.dump`` default hook that makes WAF API payloads serialisable.

    boto3 returns WAFv2 blob fields (for example ``ByteMatchStatement.SearchString``)
    as ``bytes``, which the json module cannot encode. They are written as UTF-8
    text. boto3 accepts ``str`` for blob parameters and encodes it as UTF-8, so the
    value round-trips unchanged when the backup is replayed.

    Raises:
        ValueError: for binary data that is not valid UTF-8 and so cannot round-trip.
        TypeError: for any other non-serialisable object, matching json's contract.
    """
    if isinstance(value, (bytes, bytearray)):
        try:
            return bytes(value).decode("utf-8")
        except UnicodeDecodeError as err:
            raise ValueError("backup contains non-UTF-8 binary data that cannot be stored as JSON text") from err
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_backup(region: str, web_acl_id: str, data: Dict) -> None:
    """Atomically write *data* as the JSON backup for *web_acl_id* in *region*.

    The JSON is written to a temporary file in the destination directory and then
    moved over the target with ``os.replace``. A failed or interrupted dump
    therefore never leaves a truncated backup, which would make the Web ACL
    impossible to revert.
    """
    path = load_backup_path(region, web_acl_id)
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(prefix=".tmp-", suffix=".json", dir=directory)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, sort_keys=True, default=backup_json_default)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is what matters.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def load_backup(region: str, web_acl_id: str) -> Optional[Dict]:
    """Return the parsed backup for *web_acl_id*, or ``None`` when no file exists."""
    backup_file = load_backup_path(region, web_acl_id)
    if os.path.exists(backup_file):
        with open(backup_file, encoding="utf-8") as handle:
            return json.load(handle)
    return None
def build_block_all_rule(ipv4_ipset_arn: str, ipv6_ipset_arn: Optional[str]) -> Dict:
    """Build the WAFv2 rule dict that blocks any request matching the given IP sets.

    With only the IPv4 IP set, the statement is a single ``IPSetReferenceStatement``.
    When an IPv6 IP set ARN is also given, both references are combined in an
    ``OrStatement``. The rule is named ``BLOCK_ALL_RULE_NAME``, has priority 0 and
    a Block action, and has sampling and CloudWatch metrics enabled.
    """
    arns = [ipv4_ipset_arn] + ([ipv6_ipset_arn] if ipv6_ipset_arn else [])
    references = [{"IPSetReferenceStatement": {"ARN": arn}} for arn in arns]
    statement: Dict = (
        {"OrStatement": {"Statements": references}} if len(references) > 1 else references[0]
    )
    visibility = {
        "SampledRequestsEnabled": True,
        "CloudWatchMetricsEnabled": True,
        "MetricName": BLOCK_ALL_RULE_NAME,
    }
    return {
        "Name": BLOCK_ALL_RULE_NAME,
        "Priority": 0,  # top of the Web ACL, evaluated before every other rule
        "Statement": statement,
        "Action": {"Block": {}},
        "VisibilityConfig": visibility,
    }
def apply_block_all_to_acl(clients: AwsClients, region: str, web_acl_summary: Dict, dry_run: bool) -> None:
    """Insert the block-all rule at priority 0 of one Web ACL without dropping existing rules.

    Flow:
      1. Fetch the full Web ACL (rules, default action, lock token, optional settings).
      2. Skip it if a rule named BLOCK_ALL_RULE_NAME is already present (idempotent).
      3. In dry-run mode, print the plan and stop. Nothing is written locally, and
         no IP sets are created or modified in AWS.
      4. Save the local backup that ``revert`` replays.
      5. Ensure the all-IPv4 IP set, and when possible the all-IPv6 one, exist.
      6. Call update_web_acl with the new rule first and each existing rule's
         priority shifted down by one.

    UpdateWebACL replaces the whole Web ACL definition. Optional settings returned
    by GetWebACL (custom response bodies, CAPTCHA/challenge config, token domains,
    association config) are therefore sent back unchanged rather than silently
    reset. An empty Description is omitted because the API rejects zero-length
    descriptions.

    Args:
        clients: AwsClients holding the wafv2 client.
        region: Region name; also selects the backup folder.
        web_acl_summary: One list_web_acls entry (needs "Name" and "Id").
        dry_run: When True, only report what would change.

    Raises:
        botocore.exceptions.ClientError: from any WAFv2 call, except the IPv6 IP set
            errors that are deliberately tolerated below.
    """
    waf = clients.waf
    name = web_acl_summary["Name"]
    web_acl_id = web_acl_summary["Id"]

    got = waf.get_web_acl(Scope=SCOPE, Name=name, Id=web_acl_id)
    web_acl = got["WebACL"]
    lock = got["LockToken"]

    existing_rules: List[Dict] = web_acl.get("Rules", [])
    if any(r.get("Name") == BLOCK_ALL_RULE_NAME for r in existing_rules):
        print(f"[SKIP] {name} already has {BLOCK_ALL_RULE_NAME}")
        return

    print(f"[APPLY] {name}: will add {BLOCK_ALL_RULE_NAME} at priority 0 and shift {len(existing_rules)} rules")
    if dry_run:
        print("[DRY-RUN] Skipping update_web_acl()")
        return

    # Optional Web ACL settings that UpdateWebACL would reset if they were omitted.
    optional_keys = ("CustomResponseBodies", "CaptchaConfig", "ChallengeConfig", "TokenDomains", "AssociationConfig")

    # Back up the original state before changing anything in AWS.
    backup_payload = {
        "Name": name,
        "Id": web_acl_id,
        "Scope": SCOPE,
        "Region": region,
        "OriginalRules": existing_rules,
        "DefaultAction": web_acl.get("DefaultAction"),
        "VisibilityConfig": web_acl.get("VisibilityConfig"),
        "Description": web_acl.get("Description"),
    }
    for key in optional_keys:
        backup_payload[key] = web_acl.get(key)
    save_backup(region, web_acl_id, backup_payload)

    # AWS WAF does not accept /0 CIDR ranges in IP sets, so each address family is
    # covered by its two /1 halves instead of 0.0.0.0/0 and ::/0.
    _, v4_arn = ensure_ipset(
        clients,
        IPSET_V4_NAME,
        "Temporary all-IPv4 IPSet for block-all rule",
        ["0.0.0.0/1", "128.0.0.0/1"],
        region,
    )
    v6_arn = None
    try:
        _, v6_arn = ensure_ipset(
            clients,
            IPSET_V6_NAME,
            "Temporary all-IPv6 IPSet for block-all rule",
            ["::/1", "8000::/1"],
            region,
        )
    except ClientError as e:
        # Tolerate only the "invalid/unsupported" failures; anything else is a real error.
        if e.response.get("Error", {}).get("Code") not in {"WAFInvalidParameterException", "WAFUnavailableEntityException"}:
            raise
        print("[WARN] IPv6 IPSet creation failed or unsupported; proceeding with IPv4 only.")

    # New rule first, then the existing rules in priority order, each shifted down by one.
    new_rules: List[Dict] = [build_block_all_rule(v4_arn, v6_arn)]
    for rule in sorted(existing_rules, key=lambda r: r.get("Priority", 0)):
        # deepcopy rather than a JSON round-trip: blob fields such as SearchString are bytes.
        shifted = copy.deepcopy(rule)
        shifted["Priority"] = shifted.get("Priority", 0) + 1
        new_rules.append(shifted)

    update_kwargs = {
        "Name": name,
        "Scope": SCOPE,
        "Id": web_acl_id,
        "DefaultAction": web_acl["DefaultAction"],
        "Rules": new_rules,
        "VisibilityConfig": web_acl["VisibilityConfig"],
        "LockToken": lock,
    }
    if web_acl.get("Description"):
        update_kwargs["Description"] = web_acl["Description"]
    for key in optional_keys:
        if web_acl.get(key):
            update_kwargs[key] = web_acl[key]
    waf.update_web_acl(**update_kwargs)
    print(f"[DONE] {name}: block-all rule added")
def revert_acl_from_backup(clients: AwsClients, region: str, web_acl_summary: Dict, dry_run: bool) -> None:
    """Restore one Web ACL's rules and settings from its local backup.

    The backup's ``OriginalRules`` replace the current rule list, which also removes
    the block-all rule. Backups store ``null`` for settings the ACL did not have.
    For each setting, the backed-up value is used when present, otherwise the
    current one. Description is sent only when non-empty, because the API rejects
    empty or ``None`` descriptions. A warning is printed when the block-all rule is
    not currently present, because the backup may then predate later edits that
    restoring it would discard.

    Args:
        clients: AwsClients holding the wafv2 client.
        region: Region name; selects the backup folder.
        web_acl_summary: One list_web_acls entry (needs "Name" and "Id").
        dry_run: When True, only report what would change.

    Raises:
        botocore.exceptions.ClientError: from get_web_acl / update_web_acl.
    """
    waf = clients.waf
    name = web_acl_summary["Name"]
    web_acl_id = web_acl_summary["Id"]

    backup = load_backup(region, web_acl_id)
    if not backup:
        print(f"[SKIP] No backup found for {name} ({web_acl_id}), nothing to revert.")
        return

    got = waf.get_web_acl(Scope=SCOPE, Name=name, Id=web_acl_id)
    web_acl = got["WebACL"]
    lock = got["LockToken"]

    original_rules = backup.get("OriginalRules") or []
    if not any(r.get("Name") == BLOCK_ALL_RULE_NAME for r in web_acl.get("Rules", [])):
        print(f"[WARN] {name}: {BLOCK_ALL_RULE_NAME} is not present; restoring the backup may discard changes made after it was taken")
    print(f"[REVERT] {name}: will restore {len(original_rules)} original rules and remove {BLOCK_ALL_RULE_NAME} if present")
    if dry_run:
        print("[DRY-RUN] Skipping update_web_acl()")
        return

    update_kwargs = {
        "Name": name,
        "Scope": SCOPE,
        "Id": web_acl_id,
        # `or` (not a .get default): backups store None for missing values.
        "DefaultAction": backup.get("DefaultAction") or web_acl.get("DefaultAction"),
        "Rules": original_rules,
        "VisibilityConfig": backup.get("VisibilityConfig") or web_acl.get("VisibilityConfig"),
        "LockToken": lock,
    }
    description = backup.get("Description") or web_acl.get("Description")
    if description:
        update_kwargs["Description"] = description
    # UpdateWebACL replaces the whole definition; carry optional settings through.
    for key in ("CustomResponseBodies", "CaptchaConfig", "ChallengeConfig", "TokenDomains", "AssociationConfig"):
        value = backup.get(key) or web_acl.get(key)
        if value:
            update_kwargs[key] = value
    waf.update_web_acl(**update_kwargs)
    print(f"[DONE] {name}: restored original rule set")
def process_web_acls(acls: List[Dict], worker: Callable[[Dict], None], region: str, label: str) -> int:
    """Run *worker* on every Web ACL summary, isolating per-ACL failures.

    Each ACL is handled independently. An error on one ACL is reported and the
    loop continues, so a single failure cannot stop the lockdown of the others.
    The summary line counts successfully handled ACLs.

    Returns:
        0 when every ACL succeeded, 1 when at least one failed (usable as an exit code).
    """
    succeeded = 0
    failed: List[str] = []
    for acl in acls:
        acl_name = acl.get("Name", "<unnamed>")
        try:
            worker(acl)
        except Exception as e:  # per-ACL boundary of a batch CLI: report and continue
            failed.append(acl_name)
            print(f"[ERROR] {acl_name}: {e}")
        else:
            succeeded += 1
    print(f"[SUMMARY] Processed {succeeded} Web ACL(s) in {region} ({label})")
    if failed:
        print(f"[SUMMARY] {len(failed)} Web ACL(s) failed in {region} ({label}): {', '.join(failed)}")
        return 1
    return 0


def cmd_apply(clients: "AwsClients", region: str, dry_run: bool) -> int:
    """Add the block-all rule to every REGIONAL Web ACL; return a process exit code."""
    return process_web_acls(
        list_all_web_acls(clients),
        lambda acl: apply_block_all_to_acl(clients, region, acl, dry_run),
        region,
        "apply",
    )


def cmd_revert(clients: "AwsClients", region: str, dry_run: bool) -> int:
    """Restore every REGIONAL Web ACL from its local backup; return a process exit code."""
    return process_web_acls(
        list_all_web_acls(clients),
        lambda acl: revert_acl_from_backup(clients, region, acl, dry_run),
        region,
        "revert",
    )
def parse_args(argv: List[str]) -> argparse.Namespace:
    """Parse the command line into a Namespace.

    The result has ``command`` ("apply" or "revert"), ``region`` (required),
    ``role_arn`` (default None) and ``dry_run`` (default False). Both subcommands
    share the same options through a parent parser.
    """
    parser = argparse.ArgumentParser(description="Add (and revert) a temporary top-priority block-all rule to all WAFv2 REGIONAL Web ACLs.")
    subcommands = parser.add_subparsers(dest="command", required=True)

    common = argparse.ArgumentParser(add_help=False)
    common.add_argument("--region", required=True, help="AWS region, e.g. eu-west-1")
    common.add_argument("--role-arn", default=None, help="Optional role ARN to assume before making any API calls")
    common.add_argument("--dry-run", action="store_true", help="Do not perform updates; print planned changes only")

    # The returned sub-parsers need no further configuration, so they are not bound.
    subcommands.add_parser("apply", parents=[common], help="Add the block-all rule at the top of every Web ACL")
    subcommands.add_parser("revert", parents=[common], help="Revert using local backups (restores original rules)")
    return parser.parse_args(argv)
def main(argv: List[str]) -> int:
    """CLI entry point: parse *argv*, build AWS clients, and dispatch the subcommand.

    Returns the subcommand's exit code, or 2 for an unrecognised command.
    """
    args = parse_args(argv)
    clients = make_clients(args.region, args.role_arn)
    handlers = {"apply": cmd_apply, "revert": cmd_revert}
    handler = handlers.get(args.command)
    if handler is None:
        print("Unknown command")
        return 2
    return handler(clients, args.region, args.dry_run)
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment